diff --git a/.gitignore b/.gitignore index 3c8fa18..9652fdc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,7 @@ dist/ .vscode .DS_Store .dev/ + +# UI build output; placeholder.txt is kept so the go:embed succeeds. +internal/server/ui_dist/* +!internal/server/ui_dist/placeholder.txt diff --git a/AGENTS.md b/AGENTS.md index 13e1f9c..ec7e177 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -21,7 +21,8 @@ llama-swap is a light weight, transparent proxy server that provides automatic m - Follow test naming conventions like `TestProxyManager_`, `TestProcessGroup_`, etc. - Use `go test -v -run ` to run any new tests you've written. -- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w `. +- Run `gofmt -w ` before committing to fix any formatting +- Build go binaries into the ./build/ subdirectory - Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Use `make test-all` before completing work. This includes long running concurrency tests. - Use `make test-ui` after making changes to the UI in ui-svelte/ diff --git a/Makefile b/Makefile index 6a97cf8..13cd9e1 100644 --- a/Makefile +++ b/Makefile @@ -41,6 +41,8 @@ ui/node_modules: # build react UI ui: ui/node_modules cd ui-svelte && npm run build + mkdir -p internal/server/ui_dist + cp -R proxy/ui_dist/. 
internal/server/ui_dist/ # Build OSX binary mac: ui diff --git a/cmd/fake-model/main.go b/cmd/fake-model/main.go new file mode 100644 index 0000000..6f68326 --- /dev/null +++ b/cmd/fake-model/main.go @@ -0,0 +1,306 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/tidwall/gjson" +) + +var loremWords = strings.Fields( + "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor " + + "incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud " + + "exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat Duis aute " + + "irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " + + "pariatur Excepteur sint occaecat cupidatat non proident sunt in culpa qui officia " + + "deserunt mollit anim id est laborum Sed ut perspiciatis unde omnis iste natus error " + + "sit voluptatem accusantium doloremque laudantium totam rem aperiam eaque ipsa quae " + + "ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo " + + "Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit", +) + +var ( + flagListen = flag.String("listen", "localhost:9898", "listen address") + flagTokens = flag.Int("tokens", 1000, "number of tokens to return") + flagTPS = flag.Float64("tps", 75, "tokens per second") + flagLoad = flag.String("load", "0s", "simulated load duration (e.g. 
2s, 500ms)") +) + +type chunkDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +type chunkChoice struct { + Index int `json:"index"` + Delta chunkDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +type chatChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []chunkChoice `json:"choices"` +} + +type completionMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type completionChoice struct { + Index int `json:"index"` + Message completionMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type completionUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type chatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []completionChoice `json:"choices"` + Usage completionUsage `json:"usage"` +} + +func loremText(n int) string { + words := make([]string, n) + for i := range words { + words[i] = loremWords[i%len(loremWords)] + } + return strings.Join(words, " ") +} + +func sendChunk(w http.ResponseWriter, content string, finishReason *string) error { + chunk := chatChunk{ + ID: "chatcmpl-fake", + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: "fake-model", + Choices: []chunkChoice{ + { + Index: 0, + Delta: chunkDelta{Content: content}, + FinishReason: finishReason, + }, + }, + } + data, err := json.Marshal(chunk) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err +} + +// startLoading runs the countdown log and closes ready when loadDur elapses. +// If loadDur is zero, ready is closed immediately. 
+func startLoading(loadDur time.Duration) <-chan struct{} { + ready := make(chan struct{}) + if loadDur == 0 { + close(ready) + return ready + } + go func() { + deadline := time.Now().Add(loadDur) + log.Printf("loading... %s remaining", loadDur.Round(time.Second)) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + timer := time.NewTimer(loadDur) + for { + select { + case <-timer.C: + close(ready) + log.Printf("ready") + return + case <-ticker.C: + if rem := time.Until(deadline).Round(time.Second); rem > 0 { + log.Printf("loading... %s remaining", rem) + } + } + } + }() + return ready +} + +func healthHandler(ready <-chan struct{}) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + select { + case <-ready: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusServiceUnavailable) + } + } +} + +func chatHandler(ready <-chan struct{}) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + + streaming := gjson.GetBytes(body, "stream").Bool() + ctx := r.Context() + + select { + case <-ready: + case <-ctx.Done(): + return + } + + tokens := *flagTokens + tps := *flagTPS + if tps <= 0 { + tps = 1 + } + + if !streaming { + delay := time.Duration(float64(tokens) / tps * float64(time.Second)) + select { + case <-time.After(delay): + case <-ctx.Done(): + return + } + text := loremText(tokens) + resp := chatCompletion{ + ID: "chatcmpl-fake", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "fake-model", + Choices: []completionChoice{ + { + Index: 0, + Message: completionMessage{Role: "assistant", Content: text}, + FinishReason: "stop", + }, + }, + Usage: completionUsage{ + PromptTokens: 0, + CompletionTokens: tokens, + TotalTokens: tokens, + }, + 
} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + // Send role delta first + first := chatChunk{ + ID: "chatcmpl-fake", + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: "fake-model", + Choices: []chunkChoice{ + {Index: 0, Delta: chunkDelta{Role: "assistant"}}, + }, + } + if data, err := json.Marshal(first); err == nil { + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + + interval := time.Duration(float64(time.Second) / tps) + ticker := time.NewTicker(interval) + defer ticker.Stop() + + stop := "stop" + for i := 0; i < tokens; i++ { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + word := loremWords[i%len(loremWords)] + if i < tokens-1 { + if err := sendChunk(w, word+" ", nil); err != nil { + return + } + } else { + if err := sendChunk(w, word, &stop); err != nil { + return + } + } + flusher.Flush() + } + + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + } +} + +func main() { + flag.Parse() + + loadDur, err := time.ParseDuration(*flagLoad) + if err != nil { + log.Fatalf("invalid -load value %q: %v", *flagLoad, err) + } + + ready := startLoading(loadDur) + + mux := http.NewServeMux() + mux.HandleFunc("/health", healthHandler(ready)) + mux.HandleFunc("/v1/chat/completions", chatHandler(ready)) + + srv := &http.Server{ + Addr: *flagListen, + Handler: mux, + } + + go func() { + log.Printf("listening on %s (tokens=%d tps=%.1f load=%s)", + *flagListen, *flagTokens, *flagTPS, loadDur) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("server error: %v", err) + } + }() + + quit := make(chan os.Signal, 1) + signal.Notify(quit, 
syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("shutting down...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Printf("shutdown error: %v", err) + } +} diff --git a/cmd/legacy/llama-swap.go b/cmd/legacy/llama-swap.go new file mode 100644 index 0000000..2b2d26f --- /dev/null +++ b/cmd/legacy/llama-swap.go @@ -0,0 +1,249 @@ +package main + +import ( + "context" + "flag" + "fmt" + "net/http" + "os" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/gin-gonic/gin" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/perf" + "github.com/mostlygeek/llama-swap/internal/watcher" + "github.com/mostlygeek/llama-swap/proxy" +) + +var ( + version string = "0" + commit string = "abcd1234" + date string = "unknown" +) + +func main() { + // Define a command-line flag for the port + configPath := flag.String("config", "config.yaml", "config file name") + listenStr := flag.String("listen", "", "listen ip/port") + certFile := flag.String("tls-cert-file", "", "TLS certificate file") + keyFile := flag.String("tls-key-file", "", "TLS key file") + showVersion := flag.Bool("version", false, "show version of build") + watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") + mainLogger := logmon.New() + + flag.Parse() // Parse the command-line flags + + if *showVersion { + fmt.Printf("version: %s (%s), built at %s", version, commit, date) + os.Exit(0) + } + + conf, err := config.LoadConfig(*configPath) + if err != nil { + mainLogger.Errorf("Error loading config: %v", err) + os.Exit(1) + } + + if len(conf.Profiles) > 0 { + mainLogger.Warn("Profile functionality has been removed in favor of Groups. 
See the README for more information.") + } + + switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) { + case "debug": + mainLogger.SetLogLevel(logmon.LevelDebug) + case "info": + mainLogger.SetLogLevel(logmon.LevelInfo) + case "warn": + mainLogger.SetLogLevel(logmon.LevelWarn) + case "error": + mainLogger.SetLogLevel(logmon.LevelError) + default: + mainLogger.SetLogLevel(logmon.LevelInfo) + } + + mainLogger.Debugf("PID: %d", os.Getpid()) + + if mode := os.Getenv("GIN_MODE"); mode != "" { + gin.SetMode(mode) + } else { + gin.SetMode(gin.ReleaseMode) + } + + // Validate TLS flags. + var useTLS = (*certFile != "" && *keyFile != "") + if (*certFile != "" && *keyFile == "") || + (*certFile == "" && *keyFile != "") { + fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.") + os.Exit(1) + } + + // Set default ports. + if *listenStr == "" { + defaultPort := ":8080" + if useTLS { + defaultPort = ":8443" + } + listenStr = &defaultPort + } + + var mon *perf.Monitor + if !conf.Performance.Disabled { + mon, err = perf.New(conf.Performance, mainLogger) + if err != nil { + mainLogger.Errorf("failed to create monitor: %s", err.Error()) + os.Exit(1) + } + mon.Start() + } else { + mainLogger.Info("performance monitoring is disabled") + } + + // Setup channels for server management + exitChan := make(chan struct{}) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + // Context that bounds the lifetime of background watcher goroutines. 
+ watcherCtx, watcherCancel := context.WithCancel(context.Background()) + + // Create server with initial handler + srv := &http.Server{ + Addr: *listenStr, + } + + // Support for watching config and reloading when it changes + reloading := false + var reloadMutex sync.Mutex + reloadProxyManager := func() { + reloadMutex.Lock() + if reloading { + reloadMutex.Unlock() + return + } + reloading = true + reloadMutex.Unlock() + defer func() { + reloadMutex.Lock() + reloading = false + reloadMutex.Unlock() + }() + + if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { + mainLogger.Info("Reloading Configuration") + conf, err = config.LoadConfig(*configPath) + if err != nil { + mainLogger.Warnf("Unable to reload configuration: %v", err) + return + } + + mainLogger.Debug("Configuration Changed") + currentPM.Shutdown() + if mon != nil { + mon.UpdateConfig(conf.Performance) + } + newPM := proxy.New(conf) + newPM.SetVersion(date, commit, version) + newPM.SetPerfMonitor(mon) + srv.Handler = newPM + mainLogger.Debug("Configuration Reloaded") + + // wait a few seconds and tell any UI to reload + time.AfterFunc(3*time.Second, func() { + event.Emit(proxy.ConfigFileChangedEvent{ + ReloadingState: proxy.ReloadingStateEnd, + }) + }) + } else { + conf, err = config.LoadConfig(*configPath) + if err != nil { + mainLogger.Errorf("Unable to load configuration: %v", err) + os.Exit(1) + } + newPM := proxy.New(conf) + newPM.SetVersion(date, commit, version) + newPM.SetPerfMonitor(mon) + srv.Handler = newPM + } + } + + // load the initial proxy manager + reloadProxyManager() + + if *watchConfig { + go func() { + absConfigPath, err := filepath.Abs(*configPath) + if err != nil { + mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err) + return + } + mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)") + (&configwatcher.Watcher{ + Path: absConfigPath, + Interval: configwatcher.DefaultInterval, + OnChange: func() { 
+ reloadProxyManager() + }, + }).Run(watcherCtx) + }() + } + + // Signal handling + go func() { + for { + sig := <-sigChan + switch sig { + case syscall.SIGHUP: + mainLogger.Debug("Received SIGHUP") + reloadProxyManager() + case syscall.SIGINT, syscall.SIGTERM: + mainLogger.Debugf("Received signal %v, shutting down...", sig) + if mon != nil { + mon.Stop() + } + watcherCancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + if pm, ok := srv.Handler.(*proxy.ProxyManager); ok { + pm.Shutdown() + } else { + mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager") + } + + if err := srv.Shutdown(ctx); err != nil { + mainLogger.Errorf("Server shutdown: %v", err) + } + close(exitChan) + return + default: + // do nothing on other signals + } + } + }() + + // Start server + go func() { + var err error + if useTLS { + mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr) + err = srv.ListenAndServeTLS(*certFile, *keyFile) + } else { + mainLogger.Infof("llama-swap listening on http://%s", *listenStr) + err = srv.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + mainLogger.Errorf("Fatal server error: %v", err) + os.Exit(1) + } + }() + + // Wait for exit signal + <-exitChan +} diff --git a/cmd/monitor-test/main.go b/cmd/monitor-test/main.go index 6965d72..2e91148 100644 --- a/cmd/monitor-test/main.go +++ b/cmd/monitor-test/main.go @@ -8,9 +8,9 @@ import ( "strings" "time" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/mostlygeek/llama-swap/proxy/config" ) func printSysStat(s perf.SysStat) { diff --git a/cmd/test-concurrency/main.go b/cmd/test-concurrency/main.go new file mode 100644 index 0000000..a0f5e4f --- /dev/null +++ b/cmd/test-concurrency/main.go @@ -0,0 +1,96 @@ +package main + +import ( + "flag" + "fmt" + "os" + "sync" + "time" + + tea 
"github.com/charmbracelet/bubbletea" +) + +func main() { + prompt := flag.String("prompt", "Write a few sentences about the history of computing.", "user message sent to each model") + maxTokens := flag.Int("max-tokens", 256, "max_tokens per request") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [flags] [model...]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Example: %s -max-tokens 400 http://localhost:8080 A B C D\n\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + + args := flag.Args() + if len(args) < 2 { + flag.Usage() + os.Exit(1) + } + + baseURL := args[0] + models := args[1:] + + m := newModel(models) + prog := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion()) + + // Chain of triggers ensures requests are sent in the order provided. + triggers := make([]chan struct{}, len(models)) + for i := range triggers { + triggers[i] = make(chan struct{}, 1) + } + triggers[0] <- struct{}{} + + var wg sync.WaitGroup + start := time.Now() + + for i, name := range models { + wg.Add(1) + go func(idx int, mdl string) { + defer wg.Done() + + <-triggers[idx] + + reqStart := time.Now() + prog.Send(statusMsg{idx: idx, status: statusStreaming}) + + if idx+1 < len(triggers) { + triggers[idx+1] <- struct{}{} + } + + err := sendRequest(baseURL, mdl, *prompt, *maxTokens, idx, func(i int, text string) { + prog.Send(deltaMsg{idx: i, text: text}) + }) + + elapsed := time.Since(reqStart) + if err != nil { + prog.Send(statusMsg{idx: idx, status: statusError, elapsed: elapsed, err: err}) + } else { + prog.Send(statusMsg{idx: idx, status: statusDone, elapsed: elapsed}) + } + }(i, name) + } + + if _, err := prog.Run(); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + wg.Wait() + printSummary(m, start) +} + +func printSummary(m *model, start time.Time) { + fmt.Println("Summary:") + for _, p := range m.panels { + switch p.status { + case statusError: + fmt.Printf(" [%d] %-20s ERROR elapsed=%s err=%v\n", + p.idx, p.model, 
p.elapsed.Round(time.Millisecond), p.err) + case statusDone: + fmt.Printf(" [%d] %-20s done elapsed=%s\n", + p.idx, p.model, p.elapsed.Round(time.Millisecond)) + default: + fmt.Printf(" [%d] %-20s %s\n", p.idx, p.model, p.status) + } + } + fmt.Printf("all done in %s\n", time.Since(start).Round(time.Millisecond)) +} diff --git a/cmd/test-concurrency/request.go b/cmd/test-concurrency/request.go new file mode 100644 index 0000000..67ac5cb --- /dev/null +++ b/cmd/test-concurrency/request.go @@ -0,0 +1,88 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// deltaSink receives streamed text fragments for a given model panel. +type deltaSink func(idx int, text string) + +type streamDelta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` +} + +type streamChoice struct { + Delta streamDelta `json:"delta"` +} + +type streamChunk struct { + Choices []streamChoice `json:"choices"` +} + +// sendRequest streams a chat completion and forwards each content/reasoning +// delta to sink. Reasoning and assistant content are emitted into the same +// stream so they render together. 
+func sendRequest(baseURL, model, prompt string, maxTokens, idx int, sink deltaSink) error { + payload := map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "max_tokens": maxTokens, + "stream": true, + } + + body, err := json.Marshal(payload) + if err != nil { + return err + } + + resp, err := http.Post(baseURL+"/v1/chat/completions", "application/json", bytes.NewReader(body)) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(b))) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + if data == "[DONE]" { + break + } + continue + } + + var chunk streamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + for _, c := range chunk.Choices { + if c.Delta.ReasoningContent != "" { + sink(idx, c.Delta.ReasoningContent) + } + if c.Delta.Content != "" { + sink(idx, c.Delta.Content) + } + } + } + + return scanner.Err() +} diff --git a/cmd/test-concurrency/tui.go b/cmd/test-concurrency/tui.go new file mode 100644 index 0000000..9d838bf --- /dev/null +++ b/cmd/test-concurrency/tui.go @@ -0,0 +1,343 @@ +package main + +import ( + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type panelStatus int + +const ( + statusWaiting panelStatus = iota + statusStreaming + statusDone + statusError +) + +func (s panelStatus) String() string { + switch s { + case statusStreaming: + return "streaming" + case statusDone: + return "done" + case statusError: + return "error" + 
default: + return "waiting" + } +} + +// deltaMsg appends streamed text to a panel. +type deltaMsg struct { + idx int + text string +} + +// statusMsg updates a panel's lifecycle state. +type statusMsg struct { + idx int + status panelStatus + elapsed time.Duration + err error +} + +type panel struct { + idx int + model string + color lipgloss.Color + status panelStatus + buf strings.Builder + elapsed time.Duration + err error +} + +const ( + minPanelWidth = 28 + maxCols = 3 + panelHeight = 9 // total box height including border + header +) + +type model struct { + panels []*panel + focused int + vp viewport.Model + width int + height int + cols int + pw int // inner panel content width + ready bool +} + +func newModel(models []string) *model { + // Assign a stable color per unique model name (by first appearance). + colorOf := map[string]lipgloss.Color{} + panels := make([]*panel, len(models)) + for i, m := range models { + c, ok := colorOf[m] + if !ok { + c = modelPalette[len(colorOf)%len(modelPalette)] + colorOf[m] = c + } + panels[i] = &panel{idx: i, model: m, color: c, status: statusWaiting} + } + return &model{panels: panels, focused: 0} +} + +func (m *model) Init() tea.Cmd { return nil } + +func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.relayout() + m.refreshViewport(true) + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c", "esc": + return m, tea.Quit + case "tab", "right", "l": + m.setFocus(m.focused + 1) + return m, nil + case "shift+tab", "left", "h": + m.setFocus(m.focused - 1) + return m, nil + } + var cmd tea.Cmd + m.vp, cmd = m.vp.Update(msg) + return m, cmd + + case tea.MouseMsg: + if msg.Action == tea.MouseActionPress && msg.Button == tea.MouseButtonLeft { + if idx, ok := m.panelAt(msg.X, msg.Y); ok { + m.setFocus(idx) + } + return m, nil + } + var cmd tea.Cmd + m.vp, cmd = m.vp.Update(msg) + return m, 
cmd + + case deltaMsg: + p := m.panels[msg.idx] + p.buf.WriteString(msg.text) + if msg.idx == m.focused { + atBottom := m.vp.AtBottom() + m.refreshViewport(false) + if atBottom { + m.vp.GotoBottom() + } + } + return m, nil + + case statusMsg: + p := m.panels[msg.idx] + p.status = msg.status + p.elapsed = msg.elapsed + p.err = msg.err + if msg.err != nil { + errTxt := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Render("\n" + msg.err.Error()) + p.buf.WriteString(errTxt) + if msg.idx == m.focused { + m.refreshViewport(false) + m.vp.GotoBottom() + } + } + return m, nil + } + + return m, nil +} + +func (m *model) setFocus(idx int) { + if len(m.panels) == 0 { + return + } + if idx < 0 { + idx = len(m.panels) - 1 + } + if idx >= len(m.panels) { + idx = 0 + } + if idx == m.focused { + return + } + m.focused = idx + m.refreshViewport(true) +} + +// relayout recomputes grid columns and panel/viewport dimensions. +func (m *model) relayout() { + if m.width < minPanelWidth+4 { + m.cols = 1 + } else { + m.cols = m.width / (minPanelWidth + 2) + if m.cols > maxCols { + m.cols = maxCols + } + if m.cols > len(m.panels) { + m.cols = len(m.panels) + } + if m.cols < 1 { + m.cols = 1 + } + } + + // inner content width: total width / cols, minus borders+padding (4) and gap. + boxOuter := m.width/m.cols - 1 + m.pw = boxOuter - 4 + if m.pw < 8 { + m.pw = 8 + } + + m.vp = viewport.New(m.pw, panelHeight-2) + m.ready = true +} + +func (m *model) refreshViewport(reset bool) { + if !m.ready || len(m.panels) == 0 { + return + } + content := lipgloss.NewStyle().Width(m.pw).Render(m.panels[m.focused].buf.String()) + m.vp.SetContent(content) + if reset { + m.vp.GotoBottom() + } +} + +// panelAt maps screen coordinates to a panel index based on the grid layout. 
+func (m *model) panelAt(x, y int) (int, bool) { + if m.cols == 0 { + return 0, false + } + boxOuterW := m.width/m.cols + 1 + col := x / boxOuterW + row := y / panelHeight + idx := row*m.cols + col + if col < m.cols && idx >= 0 && idx < len(m.panels) { + return idx, true + } + return 0, false +} + +func (m *model) View() string { + if !m.ready { + return "loading..." + } + + rows := []string{} + var current []string + for i, p := range m.panels { + current = append(current, m.renderPanel(p, i == m.focused)) + if len(current) == m.cols { + rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...)) + current = nil + } + } + if len(current) > 0 { + rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...)) + } + + grid := lipgloss.JoinVertical(lipgloss.Left, rows...) + footer := lipgloss.NewStyle().Faint(true).Render( + "tab/click: focus panel • wheel/↑↓/pgup/pgdn: scroll focused • q: quit") + return grid + "\n" + footer +} + +// modelPalette gives each panel a distinct, readable color for its name. +var modelPalette = []lipgloss.Color{ + "39", // blue + "213", // magenta + "214", // orange + "45", // cyan + "141", // purple + "203", // salmon + "82", // lime + "227", // light yellow +} + +func statusColor(s panelStatus) lipgloss.Color { + switch s { + case statusStreaming: + return lipgloss.Color("220") // yellow - active + case statusDone: + return lipgloss.Color("42") // green - success + case statusError: + return lipgloss.Color("196") // red - error + default: + return lipgloss.Color("244") // gray - waiting + } +} + +func (m *model) renderPanel(p *panel, focused bool) string { + border := lipgloss.RoundedBorder() + if focused { + border = lipgloss.DoubleBorder() + } + style := lipgloss.NewStyle(). + Border(border). 
+ BorderForeground(lipgloss.Color("240")) + + statusTxt := p.status.String() + if p.elapsed > 0 { + statusTxt += " " + p.elapsed.Round(time.Millisecond).String() + } + + // Header: model name (left, model color) + status/timer (right, status color). + name := fmt.Sprintf("[%d] %s", p.idx, p.model) + gap := m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt) + if gap < 1 { + name = truncate(name, m.pw-lipgloss.Width(statusTxt)-1) + gap = m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt) + } + if gap < 1 { + gap = 1 + } + header := lipgloss.NewStyle().Bold(true).Foreground(p.color).Render(name) + + strings.Repeat(" ", gap) + + lipgloss.NewStyle().Foreground(statusColor(p.status)).Render(statusTxt) + + var bodyLines string + if focused { + bodyLines = m.vp.View() + } else { + bodyLines = tailLines(p.buf.String(), m.pw, panelHeight-2) + } + + content := lipgloss.JoinVertical(lipgloss.Left, header, bodyLines) + return style.Width(m.pw).Height(panelHeight - 2).Render(content) +} + +func truncate(s string, w int) string { + if w <= 0 { + return "" + } + if lipgloss.Width(s) <= w { + return s + } + r := []rune(s) + if len(r) > w { + r = r[:w] + } + return string(r) +} + +// tailLines wraps text to width w and returns the last n lines. 
+func tailLines(s string, w, n int) string { + wrapped := lipgloss.NewStyle().Width(w).Render(s) + lines := strings.Split(wrapped, "\n") + if len(lines) > n { + lines = lines[len(lines)-n:] + } + for len(lines) < n { + lines = append(lines, "") + } + return strings.Join(lines, "\n") +} diff --git a/docs/newrouter-todo.md b/docs/newrouter-todo.md new file mode 100644 index 0000000..46a596b --- /dev/null +++ b/docs/newrouter-todo.md @@ -0,0 +1,264 @@ +# New Router Migration TODO + +This document tracks the work needed for [cmd/newrouter/main.go](../cmd/newrouter/main.go) and [internal/router/](../internal/router/) to reach feature parity with the legacy entrypoint at [llama-swap.go](../llama-swap.go) plus [proxy/proxymanager.go](../proxy/proxymanager.go). + +The work is split into phases so each can land and be tested independently. Earlier phases unblock later ones. + +## Current state (newrouter) + +`cmd/newrouter` already supports: + +- Loading config via `-config` +- Selecting Matrix vs Group router based on config +- Peer routing fallback +- Plain HTTP listen (`-listen`) +- Graceful shutdown on `SIGINT` / `SIGTERM` +- Model extraction from JSON body, query string, and form bodies (see [router.go:88](../internal/router/router.go#L88)) +- `Server.ServeHTTP` dispatches a single request to peer or local router based on the requested model + +Everything below is missing or only partially implemented. + +--- + +## Phase 1 — Package relocation -- Completed. + +Goal: move shared infrastructure packages out from under `proxy/` so the new router does not depend on the legacy proxy tree. This is a prerequisite for retiring `proxy/` in Phase 8. + +--- + +## Phase 2 — Server lifecycle parity -- Completed. + +Goal: make `cmd/newrouter` a drop-in replacement for the legacy binary's process model, _without_ yet adding any extra HTTP endpoints. + +--- + +## Phase 3 — `internal/chain` package -- Completed. 
+ +API: `chain.New(mws...).Then(final)` for ServeMux registration; `Append` returns an extended Chain without mutating the receiver, so a base stack (auth/CORS) can be reused across many routes with per-route additions. + +--- + +## Phase 4 — `internal/server` package scaffolding (ProxyManager replacement) -- Completed. + +Goal: build the [internal/server](../internal/server/) package so it can stand in for [proxy.ProxyManager](../proxy/proxymanager.go#L67) — the mux, lifecycle, model dispatch, custom endpoints, request filters, auth/CORS, and upstream passthrough. After this phase, `cmd/newrouter/main.go` constructs a `server.Server` instead of a bare `router.Server`. + +The legacy `ProxyManager` collapses three concerns into one struct: the HTTP mux, the model→process router, and the cross-cutting services (loggers, metrics, perf, inflight counter, version). The new layout keeps the `router.Router` implementations focused on model dispatch and lets `internal/server.Server` own the mux and all cross-cutting middleware. `server.Server` builds the `local` and `peer` routers directly and dispatches between them itself, so it fully **supersedes `internal/router.Server`** — see the cleanup item below. 
+ +The phase is split into sub-phases that can land and be tested independently: + +| Sub-phase | Scope | +| --------- | -------------------------------------------------------------------------- | +| 4a | package scaffolding — struct, `New`, `ServeHTTP`, `Shutdown`, model routes | +| 4b | custom (non-model-dispatched) HTTP endpoints | +| 4c | request-body filter middleware | +| 4d | auth & CORS middleware | +| 4e | upstream passthrough | + +The package is split by concern across stub files already in place: + +| File | Responsibility | Filled in by | +| ------------ | ----------------------------------------------- | ---------------------- | +| `server.go` | `Server` struct, `New`, `ServeHTTP`, `Shutdown` | 4a | +| `log.go` | `muxlog` combined logger; `/logs` handlers | 4a | +| `auth.go` | `CreateAuthMiddleware` | 4d | +| `filters.go` | request-body filter middleware | 4c | +| `api.go` | llama-swap-specific API handlers | 4b / Phase 5 / Phase 6 | +| `ui.go` | embedded UI serving | Phase 7 | + +### Phase 4a — package scaffolding -- Completed. + +`server.Server` owns the mux, the `local`/`peer` routers, `muxlog`, and a +shutdown context. `New` builds the routers, registers all model-dispatched +routes on a stdlib `http.ServeMux`, and wraps the mux with the global CORS +middleware. `localPeerHandler` resolves the model once via `router.FetchModel` +and dispatches to `local` or `peer`. `Shutdown` stops both routers in parallel +and is idempotent. `cmd/newrouter/main.go` now constructs `server.New(...)`; +`internal/router/server.go` and `server_test.go` were removed as dead code. + +### Phase 4b — Custom HTTP endpoints -- Completed. + +`GET /v1/models` (local + peer models, aliases, metadata), `GET /health`, +`GET /wol-health`, and `GET /` → `/ui` are registered. `GET /favicon.ico` is +deferred to Phase 7 since it requires the embedded UI filesystem. + +### Phase 4c — Request-body filters -- Completed. 
+ +`CreateFilterMiddleware` (in `filters.go`) applies `UseModelName`, +`StripParams`, `SetParams`, and `SetParamsByID` to JSON requests, then +re-attaches the body with `Content-Length` / `Transfer-Encoding` cleanup. + +### Phase 4d — Auth & CORS -- Completed. + +`CreateAuthMiddleware` validates API keys (Bearer / Basic / `x-api-key`) and +strips the headers before upstream. `CreateCORSMiddleware` answers OPTIONS +preflight; `/v1/models` echoes the `Origin`. + +### Phase 4e — Upstream passthrough -- Completed. + +`GET /upstream` → `/ui/models`, and `/upstream/<model>/<path>` proxies to the +resolved model with multi-segment name resolution, canonical-form redirect +(301/308), and prefix stripping. + +--- + +## Phase 5 — Operations endpoints -- Completed. + +A new `router.LocalRouter` interface embeds `Router` and adds `RunningModels()` +and `Unload(timeout, models...)`, both implemented once on `baseRouter` so +`Group` and `Matrix` share them — the legacy matrix/group divergence at +[proxymanager.go:1167](../proxy/proxymanager.go#L1167) collapses since +`baseRouter` already unifies process storage. `Peer` does not implement it; +`Server.local` is typed `LocalRouter`, `Server.peer` stays `Router`. + +`GET /unload` stops every local process; `GET /running` lists non-stopped +processes joined against config for `cmd`/`proxy`/`ttl`/`name`/`description`. +`startPreload` fires a background `GET /` at each `Hooks.OnStartup.Preload` +model and emits `shared.ModelPreloadedEvent`. + +--- + +## Phase 6 — Metrics, perf, and SSE -- Completed. + +`perf.Monitor` is created and started in `cmd/newrouter/main.go` (it outlives +config reloads via `UpdateConfig`) and passed into `server.New`. `GET /metrics` +serves `perf.Monitor.MetricsHandler()` output, 503 when disabled. + +`internal/process` emits `shared.ProcessStateChangeEvent` from `setState`. +`server.inflightCounter` (atomic) + `CreateInflightMiddleware` track +model-dispatched requests and emit `InFlightRequestsEvent`.
`metricsMonitor` +(in `metrics.go`) parses token usage from upstream responses via +`CreateMetricsMiddleware`. + +The `/api` group (API-key protected) is registered: `POST /api/models/unload`, +`POST /api/models/unload/{model...}`, `GET /api/events` (SSE: `modelStatus` / +`logData` / `metrics` / `inflight`), `GET /api/metrics`, `GET /api/performance` +(`?after=` RFC3339 filter), `GET /api/version`. `GET /api/captures/{id}` +returns 501 until 6f. + +### Phase 6f — Request/response captures -- Completed. + +`proxy/cache` moved to `internal/cache`. `metricsMonitor` stores zstd+CBOR +`ReqRespCapture` records in a sized `cache.Cache` (`captureBuffer` MB, 0 +disables). `CreateMetricsMiddleware` buffers request body/headers before +dispatch; `record` builds the capture per a `captureFieldsByPath` table +(`captures.go`) that trims large audio/image payloads, defaulting JSON routes +to `captureAll`. `GET /api/captures/{id}` decompresses and returns the capture; +`getMetrics` resolves `HasCapture` against the cache. + +--- + +## Phase 7 — UI serving -- Completed. + +`internal/server/ui.go` embeds `ui_dist` and serves it. `GET /ui/` is +brotli/gzip-aware via `serveCompressedFile`; unknown paths without a file +extension fall back to `index.html` for SPA routing. `GET /favicon.ico` serves +from the same embedded FS. The Makefile `ui` target copies the vite build into +`internal/server/ui_dist`; a committed `placeholder.txt` keeps the embed valid +before a build runs. 
+ +--- + +## Phase 8a - Review Part I + +- [x] All functionality from the proxy package has been migrated in the above phases — with the remaining gaps listed in Phase 8b +- [x] Test coverage meets or exceeds the level from the proxy package — `internal/server` now at 76.6% vs 73.9% (`proxy`) + +### Findings + +**Gap 1 — Request logging middleware missing -- Resolved.** + +`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records one +access-log line per request to `s.proxylog` in the legacy format +`clientIP "METHOD PATH PROTO" status bodySize "UA" duration`, skipping +`/wol-health`, `/api/performance`, and `/metrics`. A `statusRecorder` captures +the status/body size (forwarding `Flush` for SSE) and `clientIP` honours +`X-Forwarded-For` / `X-Real-IP`. It is wired as the outermost middleware in +`routes()`, wrapping the CORS layer. + +**Gap 2 — Per-model log streaming not supported -- Resolved.** + +Before the fix, `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) only handled `""`, `"proxy"`, and `"upstream"`. The legacy `ProxyManager.getLogger` ([proxymanager_loghandlers.go:92](../proxy/proxymanager_loghandlers.go#L92)) additionally resolves a model ID against the active process groups / matrix and returns that process's logger. Callers of `GET /logs/stream/<model-id>` got a 400 instead of the model's live log stream. + +**Gap 3 — `UseModelName` not applied to multipart form endpoints -- Resolved.** + +`CreateFormFilterMiddleware` ([filters.go](../internal/server/filters.go)) parses +`multipart/form-data` requests, rewrites the `model` field with `UseModelName`, +reconstructs the body via `rewriteMultipartModel`, and re-attaches it with +`Content-Type` / `Content-Length` cleanup. It runs in `modelChain` after the +JSON `filterMW`; each is a no-op for the other's Content-Type. Audio +transcription (`/v1/audio/transcriptions`) and image edit (`/v1/images/edits`) +now honour `use_model_name`.
+ +**Coverage gaps (0 % functions) -- Resolved.** + +The functions previously at 0 % (`handleListModels`, `handleMetrics`, +`handleRootRedirect`, `handleUpstreamRedirect`, `handleUpstream`, +`findModelInPath`, `handleAPICapture`, `handleAPIUnloadAll`, +`handleAPIUnloadModel`, `CreateAuthMiddleware`, `extractAPIKey`, +`handleLogStream`, `applyFilters`, `decompressBody`, `filterAcceptEncoding`, +`handleUI`, `handleFavicon`) now have tests across `auth_test.go`, `api_test.go`, +`filters_test.go`, `log_test.go`, and `extras_test.go`. + +--- + +### Phase 8b - Fill gaps discovered in Phase 8a + +- [x] **Add request-log middleware** — `CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records `clientIP "METHOD PATH PROTO" status bodySize "UA" duration` to `s.proxylog`, skips `/wol-health` / `/api/performance` / `/metrics`, and is wired as the outermost middleware in `routes()`. +- [x] **Extend `getLogger` with model-ID resolution** — add a `default:` branch to `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) that resolves the ID via `s.local` (using a new `LocalRouter.GetProcess(name)` method or equivalent) and returns that process's `Logger()`. Match the fallback behaviour: return a 400 with `"invalid logger. Use 'proxy', 'upstream' or a model's ID"` when not found. +- [x] **`UseModelName` rewrite for multipart endpoints** — `CreateFormFilterMiddleware` parses `multipart/form-data`, rewrites the `model` field according to `UseModelName`, reconstructs the body, and updates `Content-Type` / `Content-Length`. It is wired into `modelChain` after the JSON filter. +- [x] **Raise test coverage to ≥ 74 %** — `internal/server` now at 76.1%; tests added for every 0 % function across `auth_test.go`, `api_test.go`, `filters_test.go`, `log_test.go`, and `extras_test.go`. 
+ +--- + +## Phase 8c - Review Part II (entrypoint comparison) + +A second pass comparing [cmd/newrouter/main.go](../cmd/newrouter/main.go) against +the legacy [llama-swap.go](../llama-swap.go) + [proxy.New](../proxy/proxymanager.go#L104) +surfaced four more gaps, all in logger setup. + +**Gap 4 — `LogToStdout` config ignored -- Resolved.** + +`cmd/newrouter/main.go` previously hardcoded `proxyLog` / `upstreamLog` to +`os.Stdout`, and the old `muxlog()` helper built a Monitor that nothing wrote +into — so `logToStdout` had no effect and `/logs` (combined history) was always +empty. `server.NewLoggers` ([log.go](../internal/server/log.go)) now replicates +the legacy switch: `proxy` / `upstream` monitors feed `muxLog` (or `io.Discard`) +per `none` / `both` / `upstream` / `proxy`, so `muxLog` accumulates the combined +history. `server.New` takes `muxlog` as a parameter. The loggers outlive config +reloads, so a `LogToStdout` change requires a restart to take effect. + +**Gap 5 — `LogTimeFormat` config ignored -- Resolved.** + +`cmd/newrouter/main.go` now maps `cfg.LogTimeFormat` to a Go time layout via the +`logTimeFormats` table and applies it (alongside log level) to the proxy and +upstream monitors in `applyLogSettings`, re-applied on config reload. + +**Gap 6 — `LogRequests` deprecation warning missing.** + +The legacy [proxymanager.go:127](../proxy/proxymanager.go#L127) warns when the +deprecated `logRequests` config key is set. `cmd/newrouter` does not. Low +priority — left open. + +**Gap 7 — PID debug log missing -- Resolved.** + +`cmd/newrouter/main.go` now logs `PID: %d` at debug level after `applyLogSettings`, +matching [llama-swap.go:71](../llama-swap.go#L71). 
+ +--- + +## Phase X (tbd) — Cutover + +- [ ] Swap `llama-swap.go` to delegate to `cmd/newrouter` (or rename newrouter to be the primary entrypoint) +- [ ] Update `Makefile` build targets +- [ ] Update docs / README references to the legacy binary +- [ ] Remove `proxy/proxymanager*.go` and `gin-gonic` dependency once nothing imports them +- [ ] Run `make test-all` and confirm concurrency suite still passes against the new entrypoint + +--- + +## Cross-cutting concerns to keep in mind + +- **Single body read**: legacy and newrouter both buffer the request body once. When adding filters (Phase 4c), make sure the buffered bytes flow through `Content-Length` / `Transfer-Encoding` cleanup as in [proxymanager.go:872](../proxy/proxymanager.go#L872). +- **Streaming flag in context**: legacy stashes `streaming` and `model` under `proxyCtxKey`. The new router uses `ModelKey` / `ModelIDKey` — pick one set of keys and use them consistently for metrics + log handlers. +- **Matrix vs Group divergence**: any handler that calls `swapProcessGroup` or `findGroupByModelName` in the legacy code needs a matrix branch too. The new router's `Router` interface already abstracts this — preserve that abstraction rather than reintroducing the branch in every handler. +- **Shutdown ordering**: `httpServer.Shutdown` must drain inflight requests _before_ `Server.Shutdown` tears down processes, otherwise inflight requests fail with a 502. Current newrouter ordering at [main.go:87](../cmd/newrouter/main.go#L87) is correct — keep it.
diff --git a/go.mod b/go.mod index 083cc5b..46995b6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,9 @@ go 1.26.1 require ( github.com/billziss-gh/golib v0.2.0 + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 github.com/fxamacker/cbor/v2 v2.9.1 github.com/gin-gonic/gin v1.10.0 github.com/klauspost/compress v1.18.5 @@ -11,16 +14,26 @@ require ( github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 + golang.org/x/sync v0.20.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/ebitengine/purego v0.10.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect @@ -31,13 +44,20 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // 
indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect @@ -45,6 +65,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/x448/float16 v0.8.4 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.45.0 // indirect diff --git a/go.sum b/go.sum index 131f9fb..843b8ba 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,31 @@ +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8= github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles 
v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 
h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= @@ -13,6 +35,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU= github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= @@ -47,21 +71,35 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader 
v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/rivo/uniseg v0.4.7 
h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY= github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -97,6 +135,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= @@ -104,10 +144,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.20.0 
h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= diff --git a/proxy/cache/cache.go b/internal/cache/cache.go similarity index 100% rename from proxy/cache/cache.go rename to internal/cache/cache.go diff --git a/proxy/cache/cache_test.go b/internal/cache/cache_test.go similarity index 100% rename from proxy/cache/cache_test.go rename to internal/cache/cache_test.go diff --git a/internal/chain/chain.go b/internal/chain/chain.go new file mode 100644 index 0000000..699d37b --- /dev/null +++ b/internal/chain/chain.go @@ -0,0 +1,63 @@ +// Package chain composes http.Handler middleware into a single handler. +// +// A Middleware wraps a downstream http.Handler and may run logic before or +// after delegating to it, or short-circuit by not calling next at all +// (e.g. auth failure, CORS preflight). +package chain + +import "net/http" + +// Middleware wraps an http.Handler with cross-cutting behavior. It receives +// the next handler in the chain and returns a handler that may call next, +// modify the request/response around it, or short-circuit. +type Middleware func(next http.Handler) http.Handler + +// Chain is a reusable middleware stack. 
Build it once with New (and optionally +// extend per-route with Append), then call Then to wrap each terminal handler +// when registering routes against an http.ServeMux: +// +// api := chain.New(authMW, corsMW) +// mux.Handle("/v1/chat/completions", api.Then(dispatch)) +// mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch)) +// +// Middlewares execute left-to-right: mws[0] runs first and may call into +// mws[1], and so on, with the terminal handler invoked last. A middleware +// that does not call next short-circuits the remainder of the chain. +// A zero Chain is valid and applies no middleware. +type Chain struct { + mws []Middleware +} + +// New returns a Chain that applies mws left-to-right around any terminal +// handler passed to Then. +func New(mws ...Middleware) Chain { + cp := make([]Middleware, len(mws)) + copy(cp, mws) + return Chain{mws: cp} +} + +// Append returns a new Chain with mws added after the existing middleware. +// The receiver is not modified, so a base Chain can be safely reused across +// multiple routes that each need different per-route additions. +func (c Chain) Append(mws ...Middleware) Chain { + out := make([]Middleware, 0, len(c.mws)+len(mws)) + out = append(out, c.mws...) + out = append(out, mws...) + return Chain{mws: out} +} + +// Then wraps final with the chain's middleware and returns the resulting +// handler, suitable for passing to http.ServeMux.Handle. With an empty chain, +// Then returns final unchanged. +func (c Chain) Then(final http.Handler) http.Handler { + h := final + for i := len(c.mws) - 1; i >= 0; i-- { + h = c.mws[i](h) + } + return h +} + +// ThenFunc is shorthand for Then(http.HandlerFunc(f)). 
+func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler { + return c.Then(f) +} diff --git a/internal/chain/chain_test.go b/internal/chain/chain_test.go new file mode 100644 index 0000000..707afdd --- /dev/null +++ b/internal/chain/chain_test.go @@ -0,0 +1,205 @@ +package chain + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// recordingMiddleware appends tag before calling next and "-after-"+tag after. +func recordingMiddleware(tag string, log *[]string) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *log = append(*log, tag) + next.ServeHTTP(w, r) + *log = append(*log, "after-"+tag) + }) + } +} + +func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) { + var log []string + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "final") + }) + + h := New( + recordingMiddleware("a", &log), + recordingMiddleware("b", &log), + recordingMiddleware("c", &log), + ).Then(final) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + h.ServeHTTP(rec, req) + + want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"} + if !equal(log, want) { + t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want) + } +} + +func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) { + var log []string + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "final") + }) + + gate := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "gate") + w.WriteHeader(http.StatusUnauthorized) + }) + } + + h := New( + recordingMiddleware("outer", &log), + gate, + recordingMiddleware("inner", &log), + ).Then(final) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + h.ServeHTTP(rec, req) + + if rec.Code 
!= http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + want := []string{"outer", "gate", "after-outer"} + if !equal(log, want) { + t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want) + } +} + +func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) { + header := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Set-By", "outer") + _, _ = io.WriteString(w, "outer:") + next.ServeHTTP(w, r) + }) + } + inner := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The outer middleware already set the header; we should see it. + if got := w.Header().Get("X-Set-By"); got != "outer" { + _, _ = io.WriteString(w, "missing-header;") + } + _, _ = io.WriteString(w, "inner:") + next.ServeHTTP(w, r) + }) + } + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "final") + }) + + h := New(header, inner).Then(final) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + body, _ := io.ReadAll(rec.Body) + if got := string(body); !strings.Contains(got, "outer:inner:final") { + t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final") + } + if got := rec.Header().Get("X-Set-By"); got != "outer" { + t.Fatalf("header X-Set-By: got %q, want %q", got, "outer") + } +} + +func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) { + var log []string + base := New( + recordingMiddleware("auth", &log), + recordingMiddleware("cors", &log), + ) + + mux := http.NewServeMux() + mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "handler-a") + })) + mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "handler-b") + })) + + srv := httptest.NewServer(mux) + defer 
srv.Close() + + for _, path := range []string{"/a", "/b"} { + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("GET %s: %v", path, err) + } + resp.Body.Close() + } + + want := []string{ + "auth", "cors", "handler-a", "after-cors", "after-auth", + "auth", "cors", "handler-b", "after-cors", "after-auth", + } + if !equal(log, want) { + t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want) + } +} + +func TestChain_AppendDoesNotMutateReceiver(t *testing.T) { + var log []string + base := New(recordingMiddleware("base", &log)) + extended := base.Append(recordingMiddleware("extra", &log)) + + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "final") + }) + + // Run extended first to surface any aliasing of the underlying slice. + rec := httptest.NewRecorder() + extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + rec = httptest.NewRecorder() + base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + want := []string{ + "base", "extra", "final", "after-extra", "after-base", + "base", "final", "after-base", + } + if !equal(log, want) { + t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want) + } +} + +func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) { + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + }) + + for name, c := range map[string]Chain{ + "zero": {}, + "empty": New(), + } { + t.Run(name, func(t *testing.T) { + h := c.Then(final) + if _, ok := h.(http.HandlerFunc); !ok { + t.Fatalf("expected http.HandlerFunc identity, got %T", h) + } + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + if rec.Code != http.StatusTeapot { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot) + } + }) + } +} + +func equal(a, b []string) bool { + if len(a) != len(b) { + return false + } + 
for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/proxy/config/config.go b/internal/config/config.go similarity index 99% rename from proxy/config/config.go rename to internal/config/config.go index e6ff5f1..103daff 100644 --- a/proxy/config/config.go +++ b/internal/config/config.go @@ -272,6 +272,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { nextPort := config.StartPort for _, modelId := range modelIds { modelConfig := config.Models[modelId] + modelConfig.HealthCheckTimeout = config.HealthCheckTimeout // Strip comments from command fields modelConfig.Cmd = StripComments(modelConfig.Cmd) diff --git a/proxy/config/config_posix_test.go b/internal/config/config_posix_test.go similarity index 79% rename from proxy/config/config_posix_test.go rename to internal/config/config_posix_test.go index f5d0070..5ac9b0a 100644 --- a/proxy/config/config_posix_test.go +++ b/internal/config/config_posix_test.go @@ -189,42 +189,46 @@ groups: SendLoadingState: false, Models: map[string]ModelConfig{ "model1": { - Cmd: "path/to/cmd --arg1 one", - Proxy: "http://localhost:8080", - Aliases: []string{"m1", "model-one"}, - Env: []string{"VAR1=value1", "VAR2=value2"}, - CheckEndpoint: "/health", - Name: "Model 1", - Description: "This is model 1", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", + Proxy: "http://localhost:8080", + Aliases: []string{"m1", "model-one"}, + Env: []string{"VAR1=value1", "VAR2=value2"}, + CheckEndpoint: "/health", + Name: "Model 1", + Description: "This is model 1", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model2": { - Cmd: "path/to/server --arg1 one", - Proxy: "http://localhost:8081", - Aliases: []string{"m2"}, - Env: []string{}, - CheckEndpoint: "/", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/server --arg1 one", + Proxy: "http://localhost:8081", + Aliases: 
[]string{"m2"}, + Env: []string{}, + CheckEndpoint: "/", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model3": { - Cmd: "path/to/cmd --arg1 one", - Proxy: "http://localhost:8081", - Aliases: []string{"mthree"}, - Env: []string{}, - CheckEndpoint: "/", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", + Proxy: "http://localhost:8081", + Aliases: []string{"mthree"}, + Env: []string{}, + CheckEndpoint: "/", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model4": { - Cmd: "path/to/cmd --arg1 one", - Proxy: "http://localhost:8082", - CheckEndpoint: "/", - Aliases: []string{}, - Env: []string{}, - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", + Proxy: "http://localhost:8082", + CheckEndpoint: "/", + Aliases: []string{}, + Env: []string{}, + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, }, HealthCheckTimeout: 15, diff --git a/proxy/config/config_test.go b/internal/config/config_test.go similarity index 100% rename from proxy/config/config_test.go rename to internal/config/config_test.go diff --git a/proxy/config/config_windows_test.go b/internal/config/config_windows_test.go similarity index 77% rename from proxy/config/config_windows_test.go rename to internal/config/config_windows_test.go index 1777bdf..dcba2d7 100644 --- a/proxy/config/config_windows_test.go +++ b/internal/config/config_windows_test.go @@ -176,44 +176,48 @@ groups: SendLoadingState: false, Models: map[string]ModelConfig{ "model1": { - Cmd: "path/to/cmd --arg1 one", - CmdStop: "taskkill /f /t /pid ${PID}", - Proxy: "http://localhost:8080", - Aliases: []string{"m1", "model-one"}, - Env: []string{"VAR1=value1", "VAR2=value2"}, - CheckEndpoint: "/health", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", 
+ CmdStop: "taskkill /f /t /pid ${PID}", + Proxy: "http://localhost:8080", + Aliases: []string{"m1", "model-one"}, + Env: []string{"VAR1=value1", "VAR2=value2"}, + CheckEndpoint: "/health", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model2": { - Cmd: "path/to/server --arg1 one", - CmdStop: "taskkill /f /t /pid ${PID}", - Proxy: "http://localhost:8081", - Aliases: []string{"m2"}, - Env: []string{}, - CheckEndpoint: "/", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/server --arg1 one", + CmdStop: "taskkill /f /t /pid ${PID}", + Proxy: "http://localhost:8081", + Aliases: []string{"m2"}, + Env: []string{}, + CheckEndpoint: "/", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model3": { - Cmd: "path/to/cmd --arg1 one", - CmdStop: "taskkill /f /t /pid ${PID}", - Proxy: "http://localhost:8081", - Aliases: []string{"mthree"}, - Env: []string{}, - CheckEndpoint: "/", - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", + CmdStop: "taskkill /f /t /pid ${PID}", + Proxy: "http://localhost:8081", + Aliases: []string{"mthree"}, + Env: []string{}, + CheckEndpoint: "/", + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, "model4": { - Cmd: "path/to/cmd --arg1 one", - CmdStop: "taskkill /f /t /pid ${PID}", - Proxy: "http://localhost:8082", - CheckEndpoint: "/", - Aliases: []string{}, - Env: []string{}, - SendLoadingState: &modelLoadingState, - Timeouts: defaultTimeout, + Cmd: "path/to/cmd --arg1 one", + CmdStop: "taskkill /f /t /pid ${PID}", + Proxy: "http://localhost:8082", + CheckEndpoint: "/", + Aliases: []string{}, + Env: []string{}, + SendLoadingState: &modelLoadingState, + Timeouts: defaultTimeout, + HealthCheckTimeout: 15, }, }, HealthCheckTimeout: 15, diff --git a/proxy/config/filters.go b/internal/config/filters.go similarity index 100% 
rename from proxy/config/filters.go rename to internal/config/filters.go diff --git a/proxy/config/filters_test.go b/internal/config/filters_test.go similarity index 100% rename from proxy/config/filters_test.go rename to internal/config/filters_test.go diff --git a/proxy/config/macro_in_macro_test.go b/internal/config/macro_in_macro_test.go similarity index 100% rename from proxy/config/macro_in_macro_test.go rename to internal/config/macro_in_macro_test.go diff --git a/proxy/config/matrix.go b/internal/config/matrix.go similarity index 100% rename from proxy/config/matrix.go rename to internal/config/matrix.go diff --git a/proxy/config/matrix_dsl.go b/internal/config/matrix_dsl.go similarity index 100% rename from proxy/config/matrix_dsl.go rename to internal/config/matrix_dsl.go diff --git a/proxy/config/matrix_dsl_test.go b/internal/config/matrix_dsl_test.go similarity index 100% rename from proxy/config/matrix_dsl_test.go rename to internal/config/matrix_dsl_test.go diff --git a/proxy/config/matrix_test.go b/internal/config/matrix_test.go similarity index 100% rename from proxy/config/matrix_test.go rename to internal/config/matrix_test.go diff --git a/proxy/config/model_config.go b/internal/config/model_config.go similarity index 97% rename from proxy/config/model_config.go rename to internal/config/model_config.go index 108e79a..68c0ffc 100644 --- a/proxy/config/model_config.go +++ b/internal/config/model_config.go @@ -54,6 +54,9 @@ type ModelConfig struct { // Timeout settings for proxy connections Timeouts TimeoutsConfig `yaml:"timeouts"` + + // Copy of HealthCheckTimeout from global config + HealthCheckTimeout int `yaml:"healthCheckTimeout"` } func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { diff --git a/proxy/config/model_config_test.go b/internal/config/model_config_test.go similarity index 100% rename from proxy/config/model_config_test.go rename to internal/config/model_config_test.go diff --git a/proxy/config/peer.go 
b/internal/config/peer.go similarity index 100% rename from proxy/config/peer.go rename to internal/config/peer.go diff --git a/proxy/config/peer_test.go b/internal/config/peer_test.go similarity index 100% rename from proxy/config/peer_test.go rename to internal/config/peer_test.go diff --git a/proxy/config/performance.go b/internal/config/performance.go similarity index 100% rename from proxy/config/performance.go rename to internal/config/performance.go diff --git a/proxy/config/performance_test.go b/internal/config/performance_test.go similarity index 100% rename from proxy/config/performance_test.go rename to internal/config/performance_test.go diff --git a/event/README.md b/internal/event/README.md similarity index 100% rename from event/README.md rename to internal/event/README.md diff --git a/event/default.go b/internal/event/default.go similarity index 100% rename from event/default.go rename to internal/event/default.go diff --git a/event/default_test.go b/internal/event/default_test.go similarity index 100% rename from event/default_test.go rename to internal/event/default_test.go diff --git a/event/event.go b/internal/event/event.go similarity index 100% rename from event/event.go rename to internal/event/event.go diff --git a/event/event_test.go b/internal/event/event_test.go similarity index 100% rename from event/event_test.go rename to internal/event/event_test.go diff --git a/internal/logmon/logging.go b/internal/logmon/logging.go index 5533187..5bdeb1b 100644 --- a/internal/logmon/logging.go +++ b/internal/logmon/logging.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/event" ) const DataEventID = 0x04 diff --git a/internal/perf/monitor.go b/internal/perf/monitor.go index 6b7cdfe..1e2c1c7 100644 --- a/internal/perf/monitor.go +++ b/internal/perf/monitor.go @@ -6,9 +6,9 @@ import ( "sync" "time" + "github.com/mostlygeek/llama-swap/internal/config" 
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/ring" - "github.com/mostlygeek/llama-swap/proxy/config" ) var ( diff --git a/internal/perf/monitor_test.go b/internal/perf/monitor_test.go index 6ae5363..5a500a0 100644 --- a/internal/perf/monitor_test.go +++ b/internal/perf/monitor_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/perf/prometheus_test.go b/internal/perf/prometheus_test.go index dec57cf..b246516 100644 --- a/internal/perf/prometheus_test.go +++ b/internal/perf/prometheus_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/process/helpers_test.go b/internal/process/helpers_test.go new file mode 100644 index 0000000..65e493b --- /dev/null +++ b/internal/process/helpers_test.go @@ -0,0 +1,49 @@ +package process + +import ( + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +var simpleResponderPath string + +func skipIfNoSimpleResponder(t *testing.T) { + t.Helper() + if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) { + t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath) + } +} + +func TestMain(m *testing.M) { + goos := runtime.GOOS + goarch := runtime.GOARCH + if goos == "windows" { + simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe") + } else { + simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) + } + m.Run() +} + +func getFreePort(t *testing.T) int { + t.Helper() + l, err := 
net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("getFreePort: %v", err) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} + +func simpleResponderCmd(t *testing.T, args ...string) (string, int) { + port := getFreePort(t) + cmdPath := filepath.ToSlash(simpleResponderPath) + base := []string{cmdPath, fmt.Sprintf("-port %d", port)} + base = append(base, args...) + return strings.Join(base, " "), port +} diff --git a/internal/process/process.go b/internal/process/process.go new file mode 100644 index 0000000..f91a1ac --- /dev/null +++ b/internal/process/process.go @@ -0,0 +1,49 @@ +package process + +import ( + "context" + "net/http" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +type ProcessState string + +const ( + StateStopped ProcessState = ProcessState("stopped") + StateStarting ProcessState = ProcessState("starting") + StateReady ProcessState = ProcessState("ready") + StateStopping ProcessState = ProcessState("stopping") + + // process is shutdown and will not be restarted + StateShutdown ProcessState = ProcessState("shutdown") +) + +type Process interface { + // Run starts the process and blocks until the process is terminated. + // The timeout parameter controls how long to wait for the process to get + // to a ready state to process traffic + Run(timeout time.Duration) error + + // WaitReady blocks until the process is ready to serve requests + // or the context is cancelled. It returns nil when the process is ready + WaitReady(context.Context) error + + // Stop blocks until the process has terminated. It returns nil when + // the process terminated as expected (exit 0) + Stop(timeout time.Duration) error + + // State returns the current state of the process + // Note: this is a snapshot of the state at the time of the call + // and may change at any time after the call returns. 
+ State() ProcessState + + // ServeHTTP forwards requests to the underlying process + // Calling it when the process is not ready will result in a + // 503 response with a body indicating it is a llama-swap-error + ServeHTTP(http.ResponseWriter, *http.Request) + + // Logger returns the monitor that captures this process's stdout/stderr. + Logger() *logmon.Monitor +} diff --git a/internal/process/process_command.go b/internal/process/process_command.go new file mode 100644 index 0000000..ed888ba --- /dev/null +++ b/internal/process/process_command.go @@ -0,0 +1,568 @@ +package process + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "os/exec" + "strings" + "sync/atomic" + "syscall" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +var ErrStartAborted = fmt.Errorf("aborted") + +type runReq struct { + timeout time.Duration + respond chan error +} + +type stopReq struct { + timeout time.Duration + respond chan error +} + +type waitReadyReq struct { + respond chan error +} + +type startResult struct { + cmd *exec.Cmd + cmdDone chan struct{} + handlerFn http.HandlerFunc + err error +} + +type ProcessCommand struct { + id string + config config.ModelConfig + parentCtx context.Context + + processLogger *logmon.Monitor + proxyLogger *logmon.Monitor + + runCh chan runReq + stopCh chan stopReq + waitReadyCh chan waitReadyReq + + // current ProcessState. Written only by run(); read by State() via atomic load. + state atomic.Value + + // stores the active reverse-proxy handler when the process is running. + // Written only by run(); read by ServeHTTP via atomic load. 
+ handler atomic.Pointer[http.HandlerFunc] + + lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion + inflight atomic.Int64 // current in-flight ServeHTTP calls +} + +var _ Process = (*ProcessCommand)(nil) + +func New( + parentCtx context.Context, + id string, + conf config.ModelConfig, + processLogger *logmon.Monitor, + proxyLogger *logmon.Monitor, +) (*ProcessCommand, error) { + p := &ProcessCommand{ + id: id, + config: conf, + parentCtx: parentCtx, + processLogger: processLogger, + proxyLogger: proxyLogger, + + runCh: make(chan runReq), + stopCh: make(chan stopReq), + waitReadyCh: make(chan waitReadyReq), + } + p.state.Store(StateStopped) + + go p.run() + return p, nil +} + +func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger } + +// run is the single-writer goroutine that owns all mutable lifecycle state +// (current ProcessState, the running *exec.Cmd, the active reverse-proxy +// handler, and the list of WaitReady subscribers). Run, Stop and WaitReady +// are thin clients that send a request on one of the runCh / stopCh / +// waitReadyCh channels and wait for a response — this funnels concurrent +// callers through a single serialization point so the state machine never +// observes a race. State() only reads the atomic mirror p.state. +func (p *ProcessCommand) run() { + // Mutable state — only read/written from this goroutine. ServeHTTP reads + // p.handler concurrently, which is why handler is an atomic.Pointer. + // p.state mirrors `state` so State() can observe transitions; setState + // writes both. + state := StateStopped + setState := func(s ProcessState) { + old := state + state = s + p.state.Store(s) + if old != s { + event.Emit(shared.ProcessStateChangeEvent{ + ProcessName: p.id, + OldState: string(old), + NewState: string(s), + }) + } + } + var ( + cmd *exec.Cmd + cmdDone <-chan struct{} + readyWaiters []waitReadyReq + // runResp parks the in-flight Run caller's response channel. 
The + // interface contract is that Run blocks until the process is + // terminated, so we hold this until Stop, parentCtx, or an + // upstream exit unblocks it via respondRun. + runResp chan<- error + ) + + // notifyWaiters wakes every blocked WaitReady caller with the given result. + // Used on transitions out of StateStarting (ready, failed, aborted, or + // shutdown) — anything that resolves the "is it ready yet?" question. + notifyWaiters := func(err error) { + for _, w := range readyWaiters { + select { + case w.respond <- err: + default: + } + } + readyWaiters = nil + } + + // respondRun delivers the final Run result, if a Run caller is parked. + respondRun := func(err error) { + if runResp != nil { + runResp <- err + runResp = nil + } + } + + for { + select { + // Shutdown: parent context cancelled. Tear down any running process, + // wake any pending WaitReady callers with an error, then exit the + // goroutine permanently. Subsequent public-method calls will fail + // because parentCtx.Done() unblocks their send-side selects. + case <-p.parentCtx.Done(): + // Mark shutdown before killProcess so concurrent State() readers + // stop treating this process as ready while the (possibly slow) + // teardown is in progress. + setState(StateShutdown) + if cmd != nil { + p.handler.Store(nil) + p.killProcess(cmd, cmdDone, 100*time.Millisecond) + cmd = nil + cmdDone = nil + } + notifyWaiters(fmt.Errorf("[%s] shutdown", p.id)) + respondRun(fmt.Errorf("[%s] shutdown", p.id)) + return + + // Upstream exited on its own (not via Stop). Drop handler state, + // transition to Stopped, and unblock the parked Run caller. + // cmdDone is nil while no process is running, so this case is + // dormant outside of StateReady. 
+ case <-cmdDone: + cmd = nil + cmdDone = nil + p.handler.Store(nil) + setState(StateStopped) + respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id)) + + // WaitReady: if we're already in a terminal-for-this-question state, + // respond immediately; otherwise queue the caller and let a future + // state transition wake them via notifyWaiters. + case req := <-p.waitReadyCh: + switch state { + case StateReady: + req.respond <- nil + case StateShutdown: + req.respond <- fmt.Errorf("[%s] shutdown", p.id) + default: + readyWaiters = append(readyWaiters, req) + } + + // Run: start the upstream process. Only valid from StateStopped. + // doStart can take a long time (health-check polling), so it runs in + // a separate goroutine and we wait on resultCh. While waiting we also + // listen for an incoming Stop — that's how callers cancel an in-flight + // start. + case req := <-p.runCh: + if state != StateStopped { + req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state) + continue + } + setState(StateStarting) + + startCtx, cancelStart := context.WithCancel(context.Background()) + resultCh := make(chan startResult, 1) + go func() { + resultCh <- p.doStart(startCtx, req.timeout) + }() + + // pendingStop holds a Stop request that arrived mid-start, so we + // can respond to it AFTER we've finished tearing the start down. + var pendingStop *stopReq + select { + // doStart finished on its own — either successfully (latch + // cmd/handler and move to Ready) or with an error (back to + // Stopped). Either way wake WaitReady subscribers and reply + // to the Run caller. + case res := <-resultCh: + if res.err == nil { + cmd = res.cmd + cmdDone = res.cmdDone + fn := res.handlerFn + p.handler.Store(&fn) + setState(StateReady) + notifyWaiters(nil) + // Park the Run response — Run blocks until the process + // terminates, so we only fire this when Stop, parentCtx, + // or the upstream exit takes the process down. 
+ runResp = req.respond + + // Start TTL goroutine if configured — self-terminates + // when state leaves StateReady. + if p.config.UnloadAfter > 0 { + ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for range ticker.C { + if p.State() != StateReady { + return + } + if p.inflight.Load() != 0 { + continue + } + if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration { + p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter) + p.Stop(10 * time.Second) + return + } + } + }() + } + } else { + setState(StateStopped) + notifyWaiters(res.err) + req.respond <- res.err + } + + // Stop arrived while doStart was still running. Cancel the + // start context to abort it, then wait for doStart to return. + // If doStart had already crossed the finish line before + // cancellation took effect, it returns a live cmd that we + // must kill ourselves. The Run caller gets ErrStartAborted; the Stop + // caller is parked in pendingStop and answered below. + case stop := <-p.stopCh: + cancelStart() + res := <-resultCh + if res.cmd != nil { + p.killProcess(res.cmd, res.cmdDone, stop.timeout) + } + setState(StateStopped) + notifyWaiters(ErrStartAborted) + req.respond <- ErrStartAborted + pendingStop = &stop + + // Parent context cancelled (e.g. config reload) while doStart + // was still running. Stop() returns early when parentCtx is + // done and never sends on stopCh, so we must handle shutdown + // here to avoid leaving doStart running indefinitely. + case <-p.parentCtx.Done(): + cancelStart() + // Mark shutdown before tearing the process down: killProcess + // may block (e.g. taskkill on Windows is slow to spawn), and + // callers observing State() should see StateShutdown promptly + // rather than a stale StateStarting. 
+ setState(StateShutdown) + res := <-resultCh + if res.cmd != nil { + p.killProcess(res.cmd, res.cmdDone, 100*time.Millisecond) + } + notifyWaiters(fmt.Errorf("[%s] shutdown", p.id)) + respondRun(fmt.Errorf("[%s] shutdown", p.id)) + return + } + // cancelStart is idempotent; calling it again here ensures the + // context is released even on the success path (govet leak check). + cancelStart() + if pendingStop != nil { + pendingStop.respond <- nil + } + + // Stop: tear down a running process. + case stop := <-p.stopCh: + if cmd != nil { + setState(StateStopping) + p.killProcess(cmd, cmdDone, stop.timeout) + cmd = nil + cmdDone = nil + p.handler.Store(nil) + } + // Stop is a no-op (and not an error) when already Stopped — this + // is what makes it idempotent for callers that don't track state. + setState(StateStopped) + respondRun(nil) + stop.respond <- nil + } + } +} + +func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult { + if p.config.Proxy == "" { + return startResult{err: fmt.Errorf("upstream proxy missing")} + } + + args, err := p.config.SanitizedCommand() + if err != nil { + return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)} + } + + proxyURL, err := url.Parse(p.config.Proxy) + if err != nil { + return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)} + } + + reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL) + reverseProxy.Transport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second, + KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second, + }).DialContext, + TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second, + ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second, + ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second, + 
ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second, + } + reverseProxy.ModifyResponse = func(resp *http.Response) error { + if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { + resp.Header.Set("X-Accel-Buffering", "no") + } + return nil + } + // httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream + // disconnects after response headers have been sent. Recover here so the + // streaming termination is treated as a normal client/upstream disconnect. + // see: https://github.com/golang/go/issues/23643 + handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + if rec == http.ErrAbortHandler { + p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id) + } else { + p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec) + } + } + }() + reverseProxy.ServeHTTP(w, r) + }) + + cmd := exec.Command(args[0], args[1:]...) + cmd.Stderr = p.processLogger + cmd.Stdout = p.processLogger + cmd.Env = append(cmd.Environ(), p.config.Env...) 
+ setProcAttributes(cmd) + + p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", ")) + + cmdDone := make(chan struct{}) + if err := cmd.Start(); err != nil { + return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)} + } + + go func() { + waitErr := cmd.Wait() + if exitErr, ok := waitErr.(*exec.ExitError); ok { + p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr) + } else if waitErr != nil { + p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr) + } else { + p.proxyLogger.Debugf("<%s> process exited cleanly", p.id) + } + close(cmdDone) + }() + + if startCtx.Err() != nil { + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: ErrStartAborted} + } + + checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) + if checkEndpoint == "none" { + return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn} + } + + // Wait 250ms for the command to start up before health checking + select { + case <-startCtx.Done(): + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: ErrStartAborted} + case <-time.After(250 * time.Millisecond): + } + + deadline := time.Now().Add(healthCheckTimeout) + for { + select { + case <-startCtx.Done(): + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: ErrStartAborted} + case <-cmdDone: + return startResult{err: fmt.Errorf("upstream command exited prematurely")} + default: + } + + if time.Now().After(deadline) { + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: fmt.Errorf("health check timed out after %v", healthCheckTimeout)} + } + + req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil) + rr := httptest.NewRecorder() + reverseProxy.ServeHTTP(rr, req) + resp := rr.Result() + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + 
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint) + break + } else if startCtx.Err() != nil { + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: ErrStartAborted} + } + + select { + case <-startCtx.Done(): + p.killProcess(cmd, cmdDone, 5*time.Second) + return startResult{err: ErrStartAborted} + case <-cmdDone: + return startResult{err: fmt.Errorf("upstream command exited prematurely")} + case <-time.After(time.Second): + } + } + + return startResult{cmd: cmd, cmdDone: cmdDone, handlerFn: handlerFn} +} + +func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cmdDone <-chan struct{}, gracefulTimeout time.Duration) { + if cmd == nil || cmd.Process == nil { + return + } + + if p.config.CmdStop != "" { + stopArgs, err := config.SanitizeCommand( + strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", cmd.Process.Pid)), + ) + if err == nil { + stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...) + stopCmd.Env = cmd.Env + setProcAttributes(stopCmd) + stopCmd.Run() + } else { + cmd.Process.Signal(syscall.SIGTERM) + } + } else { + cmd.Process.Signal(syscall.SIGTERM) + } + + timer := time.NewTimer(gracefulTimeout) + defer timer.Stop() + + select { + case <-cmdDone: + case <-timer.C: + cmd.Process.Kill() + <-cmdDone + } +} + +func (p *ProcessCommand) ID() string { + return p.id +} + +func (p *ProcessCommand) Run(timeout time.Duration) error { + req := runReq{ + timeout: timeout, + respond: make(chan error, 1), + } + select { + case p.runCh <- req: + case <-p.parentCtx.Done(): + return fmt.Errorf("[%s] shutdown", p.id) + } + select { + case err := <-req.respond: + return err + case <-p.parentCtx.Done(): + return fmt.Errorf("[%s] shutdown", p.id) + } +} + +func (p *ProcessCommand) WaitReady(ctx context.Context) error { + req := waitReadyReq{respond: make(chan error, 1)} + select { + case p.waitReadyCh <- req: + case <-ctx.Done(): + return ctx.Err() + case <-p.parentCtx.Done(): + return 
fmt.Errorf("[%s] shutdown", p.id) + } + select { + case err := <-req.respond: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func (p *ProcessCommand) Stop(timeout time.Duration) error { + req := stopReq{ + timeout: timeout, + respond: make(chan error, 1), + } + select { + case p.stopCh <- req: + case <-p.parentCtx.Done(): + return fmt.Errorf("[%s] shutdown", p.id) + } + return <-req.respond +} + +func (p *ProcessCommand) State() ProcessState { + if s, ok := p.state.Load().(ProcessState); ok { + return s + } + return StateStopped +} + +func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fn := p.handler.Load() + if fn == nil { + http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable) + return + } + p.inflight.Add(1) + defer func() { + p.lastUse.Store(time.Now().UnixNano()) + p.inflight.Add(-1) + }() + (*fn)(w, r) +} diff --git a/internal/process/process_command_test.go b/internal/process/process_command_test.go new file mode 100644 index 0000000..a52051e --- /dev/null +++ b/internal/process/process_command_test.go @@ -0,0 +1,646 @@ +package process + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +const ( + testStartTimeout = 3 * time.Second + testStopTimeout = 2 * time.Second + testReturnTimeout = 1 * time.Second + testPollInterval = 20 * time.Millisecond + testLogPollInterval = 10 * time.Millisecond +) + +func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand { + t.Helper() + logger := logmon.NewWriter(io.Discard) + p, err := New(context.Background(), t.Name(), conf, logger, logger) + if err != nil { + t.Fatalf("New: %v", err) + } + return p +} + +// runAsync starts Run in a goroutine and waits until the process is ready, +// matching 
the new interface contract where Run blocks until the process is +// terminated. Returns a channel that delivers Run's eventual error. +func runAsync(t *testing.T, p *ProcessCommand) <-chan error { + t.Helper() + ch := make(chan error, 1) + go func() { ch <- p.Run(testStartTimeout) }() + ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout) + defer cancel() + if err := p.WaitReady(ctx); err != nil { + t.Fatalf("WaitReady: %v", err) + } + return ch +} + +func TestProcessCommand_StartStop(t *testing.T) { + skipIfNoSimpleResponder(t) + + cmd, port := simpleResponderCmd(t, "-silent", "-respond hello") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }) + t.Cleanup(func() { p.Stop(testStopTimeout) }) + + req := httptest.NewRequest("GET", "/test", nil) + + // before start: no handler + rr := httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("before start: expected 503, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") { + t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body) + } + + runErr := runAsync(t, p) + if got := p.State(); got != StateReady { + t.Errorf("after Run: expected state %s, got %s", StateReady, got) + } + + rr = httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("after Run: expected 200, got %d", rr.Code) + } + if body := rr.Body.String(); body != "hello" { + t.Errorf("expected body %q, got %q", "hello", body) + } + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() error: %v", err) + } + if got := p.State(); got != StateStopped { + t.Errorf("after Stop: expected state %s, got %s", StateStopped, got) + } + select { + case err := <-runErr: + if err != nil { + t.Errorf("Run() after Stop: expected nil, got %v", err) + } + case 
<-time.After(testReturnTimeout): + t.Fatal("Run() did not return after Stop") + } + + // after stop: handler cleared + rr = httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("after stop: expected 503, got %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") { + t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body) + } +} + +func TestProcessCommand_Run_Idempotent(t *testing.T) { + skipIfNoSimpleResponder(t) + + cmd, port := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }) + t.Cleanup(func() { p.Stop(testStopTimeout) }) + + runErr := runAsync(t, p) + + if err := p.Run(testStartTimeout); err == nil { + t.Error("second Run() while running: expected error, got nil") + } + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() error: %v", err) + } + select { + case <-runErr: + case <-time.After(testReturnTimeout): + t.Fatal("Run() did not return after Stop") + } +} + +func TestProcessCommand_Stop_Idempotent(t *testing.T) { + skipIfNoSimpleResponder(t) + + cmd, port := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }) + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() before Run(): %v", err) + } + + runErr := runAsync(t, p) + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("first Stop() error: %v", err) + } + select { + case <-runErr: + case <-time.After(testReturnTimeout): + t.Fatal("Run() did not return after Stop") + } + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("second Stop() error: %v", err) + } +} + +// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is 
+// executing its health-check loop returns ErrAbort to the Run caller. +// +// A blocking mock HTTP server is used as the proxy so the test can deterministically +// know when doStart is inside the health-check loop before issuing Stop. +func TestProcessCommand_StopCancelsRun(t *testing.T) { + skipIfNoSimpleResponder(t) + + healthCheckStarted := make(chan struct{}, 1) + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Signal that a health check is in-flight, then block until the client + // cancels (which happens when Stop cancels the start context). + select { + case healthCheckStarted <- struct{}{}: + default: + } + <-r.Context().Done() + http.Error(w, "mock cancelled", http.StatusServiceUnavailable) + })) + defer mock.Close() + + // simple-responder is the real process; health checks go to the blocking mock. + cmd, _ := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: mock.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 30, + }) + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- p.Run(testStartTimeout) + }() + + // Block until doStart is actually performing a health check, guaranteeing + // that Run is in-flight when Stop is called. + <-healthCheckStarted + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() error: %v", err) + } + + if err := <-runErrCh; !errors.Is(err, ErrStartAborted) { + t.Errorf("expected ErrStartAborted from Run, got %v", err) + } +} + +// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the +// parent context while doStart is health-checking causes the process to +// transition to StateShutdown promptly, not wait for the health-check timeout. +// +// This is the config-reload race: Stop() returns early when parentCtx is +// already done and never writes to stopCh, so without a parentCtx.Done() +// case in the inner select, the process would keep loading indefinitely. 
+func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) { + skipIfNoSimpleResponder(t) + + healthCheckStarted := make(chan struct{}, 1) + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case healthCheckStarted <- struct{}{}: + default: + } + <-r.Context().Done() + http.Error(w, "mock cancelled", http.StatusServiceUnavailable) + })) + defer mock.Close() + + parentCtx, cancelParent := context.WithCancel(context.Background()) + logger := logmon.NewWriter(io.Discard) + cmd, _ := simpleResponderCmd(t, "-silent") + p, err := New(parentCtx, t.Name(), config.ModelConfig{ + Cmd: cmd, + Proxy: mock.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 60, + }, logger, logger) + if err != nil { + t.Fatalf("New: %v", err) + } + + runErrCh := make(chan error, 1) + go func() { runErrCh <- p.Run(60 * time.Second) }() + + <-healthCheckStarted + + // Cancel parent context to simulate a config reload tearing down the old server. + cancelParent() + + select { + case err := <-runErrCh: + if !strings.Contains(err.Error(), "shutdown") { + t.Errorf("Run error = %v, want shutdown error", err) + } + case <-time.After(5 * time.Second): + t.Fatal("process did not shut down within 5s after parent context cancel during start") + } + + // Run() may return before the run() goroutine writes StateShutdown; + // poll briefly to avoid a spurious race in the assertion. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if p.State() == StateShutdown { + break + } + time.Sleep(testPollInterval) + } + if got := p.State(); got != StateShutdown { + t.Errorf("after cancel: expected StateShutdown, got %s", got) + } +} + +// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on +// fresh processes to confirm they are reusable. 
+func TestProcessCommand_RunStopCycle(t *testing.T) { + skipIfNoSimpleResponder(t) + + for i := range 3 { + cmd, port := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }) + + runErr := runAsync(t, p) + + req := httptest.NewRequest("GET", "/health", nil) + rr := httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code) + } + + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("cycle %d Stop() error: %v", i, err) + } + select { + case <-runErr: + case <-time.After(testReturnTimeout): + t.Fatalf("cycle %d: Run() did not return after Stop", i) + } + } +} + +// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path: +// the upstream responds healthy on /health (so Run completes), then on the +// actual proxied request it hijacks the connection and closes it mid-body. +// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error, +// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover +// and log the disconnect. +// +// Requests are issued through an httptest.NewServer wrapping the process so +// the panic actually fires (httputil only panics on copy errors when the +// request carries http.ServerContextKey, which a real server sets). +// +// see: https://github.com/golang/go/issues/23643 +func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) { + skipIfNoSimpleResponder(t) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/health" { + w.WriteHeader(http.StatusOK) + return + } + // Send a Content-Length that promises 100 bytes, deliver only a few, + // then slam the connection shut. The reverse proxy will see EOF + // before the body is fully copied and panic with ErrAbortHandler. 
+ hj, ok := w.(http.Hijacker) + if !ok { + t.Errorf("upstream: hijack not supported") + return + } + conn, _, err := hj.Hijack() + if err != nil { + t.Errorf("upstream: hijack: %v", err) + return + } + _, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial")) + _ = conn.Close() + })) + t.Cleanup(upstream.Close) + + // Capture proxy log output so we can assert the recover message was + // emitted by handlerFn. + logBuf := &syncBuffer{} + proxyLogger := logmon.NewWriter(logBuf) + procLogger := logmon.NewWriter(io.Discard) + + cmd, _ := simpleResponderCmd(t, "-silent") + p, err := New(context.Background(), t.Name(), config.ModelConfig{ + Cmd: cmd, + Proxy: upstream.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }, procLogger, proxyLogger) + if err != nil { + t.Fatalf("New: %v", err) + } + t.Cleanup(func() { p.Stop(testStopTimeout) }) + + _ = runAsync(t, p) + + // Wrap p in an httptest server so requests get http.ServerContextKey + // automatically — that is what makes httputil.ReverseProxy raise the panic. + front := httptest.NewServer(p) + t.Cleanup(front.Close) + + resp, err := http.Get(front.URL + "/disconnect") + if err == nil { + resp.Body.Close() + } + + const want = "recovered from upstream disconnection" + deadline := time.Now().Add(testReturnTimeout) + for time.Now().Before(deadline) { + if strings.Contains(logBuf.String(), want) { + return + } + time.Sleep(testLogPollInterval) + } + t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String()) +} + +// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output. 
+type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *syncBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *syncBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL +// automatically stops itself after the idle timeout has elapsed following its +// last request. +func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) { + skipIfNoSimpleResponder(t) + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(mock.Close) + + cmd, _ := simpleResponderCmd(t, "-silent") + + cfg := config.ModelConfig{ + Cmd: cmd, + Proxy: mock.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + UnloadAfter: 1, // 1-second TTL + } + if runtime.GOOS == "windows" { + cfg.CmdStop = "taskkill /f /t /pid ${PID}" + } + + p := newProcessCommand(t, cfg) + + runErr := runAsync(t, p) + defer func() { + if p.State() == StateReady { + p.Stop(testStopTimeout) + } + }() + + if got := p.State(); got != StateReady { + t.Fatalf("expected StateReady, got %s", got) + } + + // Make one request to prime the last-use timestamp. + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 after request, got %d", rr.Code) + } + + // Wait for the TTL goroutine to fire and the process to fully stop. + // Poll for StateStopped directly to avoid racing the StateStopping + // intermediate state that sits between StateReady and StateStopped. 
+ deadline := time.Now().Add(5 * time.Second) + for p.State() != StateStopped && time.Now().Before(deadline) { + time.Sleep(testPollInterval) + } + + if got := p.State(); got != StateStopped { + t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped) + } + + // Run() should have returned nil (clean stop from TTL). + select { + case err := <-runErr: + if err != nil { + t.Errorf("Run() after TTL stop: expected nil, got %v", err) + } + case <-time.After(testReturnTimeout): + t.Fatal("Run() did not return after TTL-induced stop") + } +} + +// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests +// prevent the TTL goroutine from stopping the process, and that the TTL timer +// resets after each request completes. +func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) { + skipIfNoSimpleResponder(t) + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(mock.Close) + + cmd, _ := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: mock.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + UnloadAfter: 1, // 1-second TTL + }) + + runErr := runAsync(t, p) + defer func() { + if p.State() == StateReady { + p.Stop(testStopTimeout) + } + }() + + // Keep sending requests for 1.5s — past the 1s TTL — and verify + // the process never stops while traffic is flowing. 
+ stopAt := time.Now().Add(1500 * time.Millisecond) + for time.Now().Before(stopAt) { + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if p.State() != StateReady { + t.Fatalf("process was stopped during active traffic (state=%s)", p.State()) + } + time.Sleep(10 * time.Millisecond) + } + + if got := p.State(); got != StateReady { + t.Fatalf("expected StateReady while traffic was active, got %s", got) + } + + // Now stop manually to clean up. + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() error: %v", err) + } + select { + case <-runErr: + case <-time.After(testReturnTimeout): + t.Fatal("Run() did not return after Stop") + } +} + +// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not +// spawn a TTL goroutine — the process stays ready until explicitly stopped. +func TestProcessCommand_TTL_ZeroDisables(t *testing.T) { + skipIfNoSimpleResponder(t) + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(mock.Close) + + cmd, _ := simpleResponderCmd(t, "-silent") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: mock.URL, + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + UnloadAfter: 0, // disabled + }) + + runErr := runAsync(t, p) + defer func() { + if p.State() == StateReady { + p.Stop(testStopTimeout) + } + }() + + if got := p.State(); got != StateReady { + t.Fatalf("expected StateReady, got %s", got) + } + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + p.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 after request, got %d", rr.Code) + } + + // No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is + // enough to confirm the process remains ready. 
+ time.Sleep(100 * time.Millisecond) + + if got := p.State(); got != StateReady { + t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got) + } + + // Cleanly stop. + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop() error: %v", err) + } + select { + case <-runErr: + case <-time.After(testReturnTimeout): + t.Fatal("Run() did not return after Stop") + } +} + +// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing +// pairs to exercise the race detector and verify no deadlocks occur. +func TestProcessCommand_ConcurrentRunStop(t *testing.T) { + skipIfNoSimpleResponder(t) + + for range 10 { + cmd, port := simpleResponderCmd(t, "-silent") + cfg := config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + } + + if runtime.GOOS == "windows" { + cfg.CmdStop = "taskkill /f /t /pid ${PID}" + } + + p := newProcessCommand(t, cfg) + + runDone := make(chan struct{}) + go func() { + defer close(runDone) + p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race + }() + go func() { + p.Stop(testStopTimeout) //nolint: errcheck + }() + + // Backstop: the racing Stop may have arrived before Run got on the + // channel (making it a no-op), so keep stopping until Run unblocks. 
+ deadline := time.After(testStartTimeout) + for done := false; !done; { + select { + case <-runDone: + done = true + case <-deadline: + t.Fatal("Run did not return") + case <-time.After(testPollInterval): + p.Stop(testStopTimeout) //nolint: errcheck + } + } + } +} diff --git a/internal/process/process_events_test.go b/internal/process/process_events_test.go new file mode 100644 index 0000000..f460c66 --- /dev/null +++ b/internal/process/process_events_test.go @@ -0,0 +1,82 @@ +package process + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) { + skipIfNoSimpleResponder(t) + + var mu sync.Mutex + var transitions []shared.ProcessStateChangeEvent + cancel := event.On(func(e shared.ProcessStateChangeEvent) { + if e.ProcessName != t.Name() { + return + } + mu.Lock() + transitions = append(transitions, e) + mu.Unlock() + }) + defer cancel() + + cmd, port := simpleResponderCmd(t, "-silent", "-respond hello") + p := newProcessCommand(t, config.ModelConfig{ + Cmd: cmd, + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + HealthCheckTimeout: 10, + }) + + runErr := runAsync(t, p) + if err := p.Stop(testStopTimeout); err != nil { + t.Fatalf("Stop: %v", err) + } + <-runErr + + // Events are delivered asynchronously; give the dispatcher a moment. 
+ deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + mu.Lock() + n := len(transitions) + mu.Unlock() + if n >= 4 { + break + } + time.Sleep(testPollInterval) + } + + mu.Lock() + defer mu.Unlock() + + for _, e := range transitions { + if e.OldState == e.NewState { + t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState) + } + } + + want := []string{ + string(StateStopped) + "->" + string(StateStarting), + string(StateStarting) + "->" + string(StateReady), + string(StateReady) + "->" + string(StateStopping), + string(StateStopping) + "->" + string(StateStopped), + } + got := make([]string, len(transitions)) + for i, e := range transitions { + got[i] = e.OldState + "->" + e.NewState + } + if len(got) != len(want) { + t.Fatalf("transitions = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("transitions = %v, want %v", got, want) + } + } +} diff --git a/internal/process/runtime_unix.go b/internal/process/runtime_unix.go new file mode 100644 index 0000000..f7c5adc --- /dev/null +++ b/internal/process/runtime_unix.go @@ -0,0 +1,12 @@ +//go:build !windows + +package process + +import ( + "os/exec" +) + +// setProcAttributes sets platform-specific process attributes +func setProcAttributes(cmd *exec.Cmd) { + // No-op on Unix systems +} diff --git a/internal/process/runtime_windows.go b/internal/process/runtime_windows.go new file mode 100644 index 0000000..3888db8 --- /dev/null +++ b/internal/process/runtime_windows.go @@ -0,0 +1,16 @@ +//go:build windows + +package process + +import ( + "os/exec" + "syscall" +) + +// setProcAttributes sets platform-specific process attributes +func setProcAttributes(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + HideWindow: true, + CreationFlags: 0x08000000, // CREATE_NO_WINDOW + } +} diff --git a/internal/router/base.go b/internal/router/base.go new file mode 100644 index 0000000..33ff58a --- /dev/null +++ b/internal/router/base.go @@ -0,0 
+1,775 @@ +package router + +import ( + "context" + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +type shutdownReq struct { + timeout time.Duration + respond chan error +} + +type unloadReq struct { + targets []string + timeout time.Duration + respond chan struct{} +} + +type handlerReq struct { + model string + ctx context.Context + respond chan handlerResp + positionCh chan int +} + +type handlerResp struct { + handleFunc http.HandlerFunc + err error +} + +type swapDone struct { + modelID string + err error +} + +type serveDoneEvent struct { + modelID string +} + +type activeSwap struct { + modelID string + evict []string + waiters []handlerReq +} + +// swapPlanner is the only piece of behaviour that differs between concrete +// routers. baseRouter never inspects its internals. +type swapPlanner interface { + // EvictionFor returns running model IDs that must be stopped before + // target can serve. alsoRunning lists models the baseRouter has already + // committed to loading (in-flight swaps) which the planner cannot see + // via process.State() yet. Pure decision; must not log. + EvictionFor(target string, alsoRunning []string) []string + + // OnSwapStart runs once at the start of every swap. Planners may log + // their decision here at whatever verbosity they choose. + OnSwapStart(target string) +} + +// baseRouter owns the channels, run-loop, and orchestration code shared by +// every concrete router. Concrete routers embed *baseRouter and supply a +// swapPlanner that captures how their eviction set is decided. 
+type baseRouter struct { + name string + config config.Config + processes map[string]process.Process + logger *logmon.Monitor + planner swapPlanner + + shutdownCtx context.Context + shutdownFn context.CancelFunc + shuttingDown atomic.Bool + + handlerCh chan handlerReq + shutdownCh chan shutdownReq + unloadCh chan unloadReq + swapDoneCh chan swapDone + serveDoneCh chan serveDoneEvent + + runDone chan struct{} + + // testProcessed, when non-nil, receives one event after each handlerReq + // or swapDone has been fully processed by run(). Tests use it to wait + // for run() to reach a deterministic state without sleeping. serveDone + // events are intentionally NOT signalled here so test event counts + // remain stable. + testProcessed chan struct{} +} + +func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter { + shutdownCtx, shutdownFn := context.WithCancel(context.Background()) + return &baseRouter{ + name: name, + config: conf, + processes: processes, + logger: logger, + planner: planner, + shutdownCtx: shutdownCtx, + shutdownFn: shutdownFn, + handlerCh: make(chan handlerReq), + shutdownCh: make(chan shutdownReq), + unloadCh: make(chan unloadReq), + swapDoneCh: make(chan swapDone), + serveDoneCh: make(chan serveDoneEvent), + runDone: make(chan struct{}), + } +} + +func (b *baseRouter) notifyProcessed() { + if b.testProcessed != nil { + b.testProcessed <- struct{}{} + } +} + +func (b *baseRouter) run() { + defer close(b.runDone) + + active := make(map[string]*activeSwap) + inFlight := make(map[string]int) + var queued []handlerReq + + for { + select { + case req := <-b.shutdownCh: + b.handleShutdown(req, active, queued) + return + + case req := <-b.handlerCh: + b.handleRequest(req, active, inFlight, &queued) + b.notifyProcessed() + + case req := <-b.unloadCh: + b.handleUnload(req, active, inFlight, &queued) + b.notifyProcessed() + + case ev := <-b.swapDoneCh: + 
b.handleSwapDone(ev, active, inFlight, &queued) + b.notifyProcessed() + + case ev := <-b.serveDoneCh: + b.handleServeDone(ev, active, inFlight, &queued) + } + } +} + +// grant sends a response back to the caller of ServeHTTP and tells us +// whether the caller was still there to receive it. +// +// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in +// a select waiting on it. "Unbuffered" is the important word: a send only +// completes when the other side is actively receiving. So if this send +// succeeds, we know for a fact the caller picked up the response and will +// act on it. If the caller has already given up (its request context was +// cancelled, e.g. the HTTP client disconnected) or the router is shutting +// down, the send never lands, one of the other select cases fires, and we +// report back that the grant did NOT happen. +// +// That distinction matters for in-flight bookkeeping — see grantHandler. +func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool { + select { + case req.respond <- resp: + return true + case <-req.ctx.Done(): + return false + case <-b.shutdownCtx.Done(): + return false + } +} + +// grantHandler is the "this caller can now use process p" path. It does +// two things that must stay locked together: +// +// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the +// HTTP request finishes, the run loop hears about it. +// 2. Bump inFlight[modelID] so the router knows this process is busy and +// refuses to evict it until the count comes back down. +// +// The increment is gated on grant() returning true. If grant() returns +// false, the caller already walked away and trackedServe will never run — +// which means no matching decrement will ever arrive on serveDoneCh. +// Incrementing in that case would strand the counter at >0 forever and +// the router would never again be willing to swap this model out. +// +// In short: increment if and only if we know a decrement is coming. 
+func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) { + if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) { + inFlight[modelID]++ + } +} + +// trackedServe is the wrapper that closes the loop on in-flight tracking. +// It runs p.ServeHTTP normally; the only added behaviour is a deferred +// send on serveDoneCh after the handler returns. That send is what tells +// the run loop "this model now has one fewer request in flight — go look +// at the queue again, you may be able to start a swap you previously had +// to defer." +// +// The select on shutdownCtx.Done() is a release valve: if the router is +// already shutting down, nobody is reading serveDoneCh, so we drop the +// notification rather than blocking the HTTP goroutine forever. +func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + select { + case b.serveDoneCh <- serveDoneEvent{modelID: modelID}: + case <-b.shutdownCtx.Done(): + } + }() + p.ServeHTTP(w, r) + } +} + +// handleRequest decides what to do with one incoming ServeHTTP request. It is +// called from run() and never blocks indefinitely: any work that has to wait +// (starting a process, stopping siblings, waiting for ready) is deferred to +// a swap goroutine and reported back via swapDoneCh. +// +// The decision tree, in order: +// +// 1. Unknown model — respond with ErrNoLocalModelFound and move on. +// 2. A swap to the same model is already in flight — attach this waiter so +// one swap serves all callers that asked for the same model. +// 3. Fast path — the target process is already ready, the planner sees +// nothing to evict, and no in-flight swap is evicting it. Hand back its +// ServeHTTP immediately (wrapped so the run loop knows when it ends). +// 4. 
Would collide with an in-flight swap (we'd stop their target, or +// they're stopping us) — park in the queue for handleSwapDone to drain. +// 5. Would evict a process that is still handling requests — park in the +// queue. handleServeDone will retry when the busy process drains. +// 6. Otherwise — start a new swap. This may run in parallel with other +// active swaps when their evict sets don't intersect. +func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) { + // (1) Unknown model. + p, ok := b.processes[req.model] + if !ok { + b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model) + b.grant(req, handlerResp{err: ErrNoLocalModelFound}) + return + } + + // (2) Join an in-flight swap for the same model. + if s, ok := active[req.model]; ok { + b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1) + s.waiters = append(s.waiters, req) + return + } + + evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model)) + + // (3) Fast path: ready, nothing to evict, and nobody is evicting us. + if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) { + b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model) + b.grantHandler(req, req.model, p, inFlight) + return + } + + // (4) Collision with an in-flight swap — queue. + if collidesWith(req.model, evict, active) { + b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model) + *queued = append(*queued, req) + b.broadcastQueuePositions(*queued) + return + } + + // (5) Would evict a busy process — queue until it drains. 
+ if conflictsWithInFlight(evict, inFlight) { + b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model) + *queued = append(*queued, req) + b.broadcastQueuePositions(*queued) + return + } + + // (6) Start a new (possibly parallel) swap. + b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict) + s := b.startSwap(req, evict) + active[s.modelID] = s +} + +// handleSwapDone is called from run() when a swap goroutine reports that it +// has finished. It fans out the result to every waiter that joined this swap, +// removes the swap from the active map, and then walks the queue once, +// promoting any items that no longer collide with the remaining active set. +// FIFO order is preserved: items still blocked stay in place. +func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) { + s, ok := active[ev.modelID] + if !ok { + return + } + delete(active, ev.modelID) + + for _, w := range s.waiters { + if ev.err != nil { + b.grant(w, handlerResp{err: ev.err}) + } else { + p := b.processes[ev.modelID] + b.grantHandler(w, ev.modelID, p, inFlight) + } + } + + b.drainQueue(active, inFlight, queued) +} + +// handleServeDone is called from run() each time a tracked ServeHTTP +// finishes. It decrements the per-model in-flight count and, when that +// drops to zero, retries the queue: requests whose swap was deferred +// because they would have evicted this (now-idle) process can now proceed. +func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) { + inFlight[ev.modelID]-- + if inFlight[ev.modelID] <= 0 { + delete(inFlight, ev.modelID) + b.drainQueue(active, inFlight, queued) + } +} + +// drainQueue walks the queued requests in order, re-running the handleRequest +// decision tree against the (now smaller) active set. 
Items that can now start +// or join become satisfied; items still blocked remain queued in original +// order so they get another chance on the next swap completion. +func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) { + if len(*queued) == 0 { + return + } + pending := *queued + var remaining []handlerReq + for _, req := range pending { + p, ok := b.processes[req.model] + if !ok { + b.grant(req, handlerResp{err: ErrNoLocalModelFound}) + continue + } + if s, ok := active[req.model]; ok { + b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model) + s.waiters = append(s.waiters, req) + continue + } + evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model)) + if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) { + b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model) + b.grantHandler(req, req.model, p, inFlight) + continue + } + if collidesWith(req.model, evict, active) { + remaining = append(remaining, req) + continue + } + if conflictsWithInFlight(evict, inFlight) { + remaining = append(remaining, req) + continue + } + b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict) + s := b.startSwap(req, evict) + active[s.modelID] = s + } + *queued = remaining + b.broadcastQueuePositions(*queued) +} + +// broadcastQueuePositions sends each queued request its current 1-indexed +// position. Sends are non-blocking: if the channel is full, the old value is +// drained first so the consumer always sees the latest position. 
+func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) { + for i, req := range queued { + pos := i + 1 + select { + case req.positionCh <- pos: + default: + select { + case <-req.positionCh: + default: + } + select { + case req.positionCh <- pos: + default: + } + } + } +} + +func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap { + swap := &activeSwap{ + modelID: initial.model, + evict: evict, + waiters: []handlerReq{initial}, + } + b.planner.OnSwapStart(initial.model) + go b.doSwap(initial.model, evict) + return swap +} + +// activeTargets returns the IDs of every in-flight swap target except exclude. +// baseRouter passes this to the planner so eviction decisions account for +// models that have been committed to but have not yet transitioned to +// StateStarting in their process state machine. +func activeTargets(active map[string]*activeSwap, exclude string) []string { + if len(active) == 0 { + return nil + } + out := make([]string, 0, len(active)) + for id := range active { + if id == exclude { + continue + } + out = append(out, id) + } + return out +} + +// collidesWith reports whether a new swap with this target and evict set can +// safely run alongside the currently active swaps. Same-target callers should +// JOIN (handled before this) — they do not collide with themselves. +func collidesWith(target string, evict []string, active map[string]*activeSwap) bool { + for id, s := range active { + if id == target { + continue + } + if containsString(evict, id) { + return true + } + if containsString(s.evict, target) { + return true + } + } + return false +} + +// conflictsWithInFlight reports whether any model in evict is still handling +// requests. Stopping a busy process would cancel its callers' connections, +// so the router defers the swap until those callers finish. 
+func conflictsWithInFlight(evict []string, inFlight map[string]int) bool { + for _, m := range evict { + if inFlight[m] > 0 { + return true + } + } + return false +} + +func containsString(xs []string, s string) bool { + for _, x := range xs { + if x == s { + return true + } + } + return false +} + +func (b *baseRouter) doSwap(modelID string, toStop []string) { + timeout := b.healthCheckTimeout() + + var wg sync.WaitGroup + for _, mID := range toStop { + wg.Add(1) + go func(p process.Process, id string) { + defer wg.Done() + if err := p.Stop(timeout); err != nil { + b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err) + } + }(b.processes[mID], mID) + } + wg.Wait() + + target := b.processes[modelID] + if target.State() == process.StateStopped { + go func() { + if err := target.Run(timeout); err != nil { + b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err) + } + }() + } + + err := target.WaitReady(b.shutdownCtx) + + select { + case b.swapDoneCh <- swapDone{modelID: modelID, err: err}: + case <-b.shutdownCtx.Done(): + } +} + +func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) { + shutdownErr := fmt.Errorf("%s is shutting down", b.name) + + // Cancel shutdownCtx first so any waiter that is currently parked on + // its respond channel can exit via its own shutdownCtx.Done() branch. + // The grant calls below then either land (waiter happened to receive + // before noticing shutdown) or fall through immediately via grant's + // shutdownCtx case — either way the waiter sees a non-OK response. 
+ b.shutdownFn() + + for _, s := range active { + for _, w := range s.waiters { + b.grant(w, handlerResp{err: shutdownErr}) + } + } + for _, w := range queued { + b.grant(w, handlerResp{err: shutdownErr}) + } + + stopTimeout := req.timeout + if stopTimeout <= 0 { + stopTimeout = b.healthCheckTimeout() + } + + var wg sync.WaitGroup + for i, p := range b.processes { + wg.Add(1) + go func(id string, p process.Process) { + defer wg.Done() + if err := p.Stop(stopTimeout); err != nil { + b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err) + } + }(i, p) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + if req.timeout > 0 { + select { + case <-done: + case <-time.After(req.timeout): + <-done + } + } else { + <-done + } + + req.respond <- nil +} + +func (b *baseRouter) healthCheckTimeout() time.Duration { + t := time.Duration(b.config.HealthCheckTimeout) * time.Second + if t <= 0 { + return 30 * time.Second + } + return t +} + +func (b *baseRouter) Handles(model string) bool { + _, ok := b.processes[model] + return ok +} + +func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) { + if p, ok := b.processes[modelID]; ok { + return p.Logger(), true + } + return nil, false +} + +// RunningModels returns the current state of every process that is not stopped +// or shut down. The processes map keys are fixed at construction and State() +// is a snapshot, so this is safe to call without the run loop. +func (b *baseRouter) RunningModels() map[string]process.ProcessState { + running := make(map[string]process.ProcessState) + for id, p := range b.processes { + st := p.State() + if st == process.StateStopped || st == process.StateShutdown { + continue + } + running[id] = st + } + return running +} + +// Unload stops the named models, or every running model when none are named. +// It blocks until each targeted process has stopped. 
+// +// The request is funneled through the run loop so eviction is coordinated +// with the rest of the router's state: pending swap waiters for an +// unloaded model are released with an error, queued requests for unloaded +// models are dropped, and any deferred swaps that were waiting on those +// models become eligible to start. +// +// In-flight requests being served by an unloaded process are not waited +// for — Stop kills the upstream, those callers see whatever error the +// reverse proxy surfaces and may retry. Their trackedServe defers fire +// normally and decrement inFlight as the dying handlers return. +func (b *baseRouter) Unload(timeout time.Duration, models ...string) { + targets := models + if len(targets) == 0 { + targets = make([]string, 0, len(b.processes)) + for id := range b.processes { + targets = append(targets, id) + } + } + if len(targets) == 0 { + return + } + + req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})} + select { + case b.unloadCh <- req: + case <-b.runDone: + return + } + <-req.respond +} + +// handleUnload runs on the run loop in response to an Unload call. It +// reconciles router-owned state with the impending Stop, then performs +// the Stop synchronously so callers of Unload remain blocked until each +// targeted process has actually exited. +func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) { + unloadErr := fmt.Errorf("%s: model unloaded", b.name) + + targetSet := make(map[string]bool, len(req.targets)) + for _, id := range req.targets { + targetSet[id] = true + } + + // Release waiters of any in-flight swap whose target is being + // unloaded. The swap goroutine itself is left to finish on its own; + // when its swapDone arrives, handleSwapDone will find no entry in + // active and silently drop it. 
+ for id := range targetSet { + s, ok := active[id] + if !ok { + continue + } + for _, w := range s.waiters { + b.grant(w, handlerResp{err: unloadErr}) + } + delete(active, id) + } + + // Drop queued requests addressed to unloaded models. Requests for + // other models stay queued and may benefit from drainQueue at the end. + if len(*queued) > 0 { + kept := (*queued)[:0] + for _, w := range *queued { + if targetSet[w.model] { + b.grant(w, handlerResp{err: unloadErr}) + continue + } + kept = append(kept, w) + } + *queued = kept + } + + // Stop the targeted processes. Done synchronously so Unload's caller + // can rely on "after Unload returns, the process is stopped". inFlight + // is intentionally NOT cleared here: each dying handler will fire its + // trackedServe defer and reach handleServeDone in the normal way once + // the run loop is free again. + var wg sync.WaitGroup + for id := range targetSet { + p, ok := b.processes[id] + if !ok { + continue + } + wg.Add(1) + go func(id string, p process.Process) { + defer wg.Done() + if err := p.Stop(req.timeout); err != nil { + b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err) + } + }(id, p) + } + wg.Wait() + + // Removing entries from active above may have unblocked queued + // requests that previously collided with the now-cancelled swaps. 
+ b.drainQueue(active, inFlight, queued) + + close(req.respond) +} + +func (b *baseRouter) Shutdown(timeout time.Duration) error { + if !b.shuttingDown.CompareAndSwap(false, true) { + return fmt.Errorf("%s shutdown already in progress", b.name) + } + req := shutdownReq{timeout: timeout, respond: make(chan error, 1)} + select { + case b.shutdownCh <- req: + case <-b.runDone: + return nil + } + return <-req.respond +} + +func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if b.shuttingDown.Load() { + SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + return + } + + data, err := FetchContext(req, b.config) + if err != nil { + SendError(w, req, err) + return + } + + hr := handlerReq{ + model: data.ModelID, + ctx: req.Context(), + // Unbuffered: a successful send on respond proves the waiter is + // alive and consuming. grant() relies on this to avoid handing a + // handleFunc to a cancelled waiter and leaking the inFlight count. + respond: make(chan handlerResp), + positionCh: make(chan int, 1), + } + + select { + case b.handlerCh <- hr: + case <-req.Context().Done(): + return + case <-b.shutdownCtx.Done(): + SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + return + } + + isModelReady := false + if p, ok := b.processes[data.ModelID]; ok { + isModelReady = p.State() == process.StateReady + } + shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady + + var lw *loadingWriter + cancelLoad := func() {} + if shouldShowLoading { + var swapCtx context.Context + swapCtx, cancelLoad = context.WithCancel(req.Context()) + lw = newLoadingWriter(b.logger, data.ModelID, w, req) + go lw.start(swapCtx) + go func() { + for { + select { + case pos := <-hr.positionCh: + lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos)) + case <-swapCtx.Done(): + return + } + } + }() + } + + var resp handlerResp + select { + case resp = <-hr.respond: + cancelLoad() + if lw != nil { + 
lw.waitForCompletion(1 * time.Second) + } + case <-req.Context().Done(): + cancelLoad() + if lw != nil { + lw.waitForCompletion(1 * time.Second) + } + return + case <-b.shutdownCtx.Done(): + cancelLoad() + if lw != nil { + lw.waitForCompletion(1 * time.Second) + } + SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + return + } + + if resp.err != nil { + SendError(w, req, resp.err) + return + } + resp.handleFunc(w, req) +} diff --git a/internal/router/base_test.go b/internal/router/base_test.go new file mode 100644 index 0000000..3f63f93 --- /dev/null +++ b/internal/router/base_test.go @@ -0,0 +1,863 @@ +package router + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// stubPlanner is a swapPlanner that returns a fixed eviction list per target +// and never logs. It lets the base-router tests cover shared run-loop +// behaviour without dragging in either real router's eviction rules. 
+type stubPlanner struct { + evict map[string][]string +} + +func (s *stubPlanner) EvictionFor(target string, _ []string) []string { + if s.evict == nil { + return nil + } + return s.evict[target] +} + +func (s *stubPlanner) OnSwapStart(string) {} + +func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter { + t.Helper() + conf := config.Config{HealthCheckTimeout: 5} + b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard)) + b.testProcessed = make(chan struct{}, 64) + go b.run() + t.Cleanup(func() { + if !b.shuttingDown.Load() { + _ = b.Shutdown(time.Second) + } + }) + return b +} + +func TestBaseRouter_RunningModels(t *testing.T) { + ready := newFakeProcess("ready") + ready.markReady() + starting := newFakeProcess("starting") + starting.setState(process.StateStarting) + stopped := newFakeProcess("stopped") + + b := newTestBase(t, map[string]process.Process{ + "ready": ready, "starting": starting, "stopped": stopped, + }, &stubPlanner{}) + + running := b.RunningModels() + if len(running) != 2 { + t.Fatalf("running=%v want 2 entries", running) + } + if running["ready"] != process.StateReady { + t.Errorf("ready state=%q want ready", running["ready"]) + } + if running["starting"] != process.StateStarting { + t.Errorf("starting state=%q want starting", running["starting"]) + } + if _, ok := running["stopped"]; ok { + t.Errorf("stopped process should be excluded from RunningModels") + } +} + +func TestBaseRouter_UnloadAll(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + c := newFakeProcess("c") + c.markReady() + + b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{}) + b.Unload(time.Second) + + if a.State() != process.StateStopped || c.State() != process.StateStopped { + t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State()) + } +} + +func TestBaseRouter_UnloadSpecificModel(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + c 
:= newFakeProcess("c") + c.markReady() + + b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{}) + b.Unload(time.Second, "a") + + if a.State() != process.StateStopped { + t.Errorf("a should be stopped, got %q", a.State()) + } + if c.State() != process.StateReady { + t.Errorf("c should remain ready, got %q", c.State()) + } +} + +// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its +// Stop calls concurrently rather than stopping each process serially. Each +// fakeProcess.Stop is pinned via stopBlock; the test only releases them +// after observing every stopStarted, proving all three Stops were in +// flight simultaneously. +func TestBaseRouter_Unload_StopsInParallel(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + a.stopBlock = make(chan struct{}) + pb := newFakeProcess("b") + pb.markReady() + pb.stopBlock = make(chan struct{}) + pc := newFakeProcess("c") + pc.markReady() + pc.stopBlock = make(chan struct{}) + + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{}) + + unloadDone := make(chan struct{}) + go func() { + b.Unload(time.Second, "a", "b", "c") + close(unloadDone) + }() + + // All three Stop calls must start before any of them are allowed to + // complete. If Unload was serial, only one stopStarted would fire + // until we released its stopBlock, and this would deadlock. + for _, p := range []*fakeProcess{a, pb, pc} { + select { + case <-p.stopStarted: + case <-time.After(2 * time.Second): + t.Fatalf("Stop on %s never started — Unload is not parallel", p.id) + } + } + + // Release them; Unload should now return. 
+ close(a.stopBlock) + close(pb.stopBlock) + close(pc.stopBlock) + + select { + case <-unloadDone: + case <-time.After(2 * time.Second): + t.Fatal("Unload did not return after stops released") + } + + for _, p := range []*fakeProcess{a, pb, pc} { + if p.State() != process.StateStopped { + t.Errorf("%s state=%q want stopped", p.id, p.State()) + } + if got := p.stopCalls.Load(); got != 1 { + t.Errorf("%s stopCalls=%d want 1", p.id, got) + } + } +} + +// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload +// rejoins router state: a request whose swap to the unloaded model is +// still in progress receives an error, instead of being abandoned +// against a process that's about to vanish. +func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) { + a := newFakeProcess("a") + // autoReady=false: the swap parks on WaitReady so we can interrupt + // it with Unload before it completes. + + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + w := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + b.ServeHTTP(w, newRequest("a")) + close(done) + }() + waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started + <-a.runStarted + + b.Unload(time.Second, "a") + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("ServeHTTP did not return after Unload") + } + if w.Code == http.StatusOK { + t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String()) + } + if a.State() != process.StateStopped { + t.Errorf("a state=%q want stopped", a.State()) + } +} + +// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests +// for an unloaded model receive an error rather than sitting forever in +// the queue against state the router no longer maintains. +func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + // Loading B evicts A — so a request for B while A is loading queues. 
+ planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner) + + // r1 starts the swap to A and parks on WaitReady (autoReady=false). + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, b.testProcessed, 1) + <-a.runStarted + + // r2 for B collides with A's in-flight swap and queues. + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, b.testProcessed, 1) + + // Unload B — r2 (queued, targeting B) must be released with an error. + b.Unload(time.Second, "b") + + select { + case <-done2: + case <-time.After(2 * time.Second): + t.Fatal("queued B request did not return after Unload(b)") + } + if w2.Code == http.StatusOK { + t.Errorf("queued B request: expected non-OK status, got %d", w2.Code) + } + if got := pb.runCalls.Load(); got != 0 { + t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got) + } + + // Release r1 so the test cleans up cleanly. 
+ a.markReady() + select { + case <-done1: + case <-time.After(2 * time.Second): + t.Fatal("r1 did not complete after a.markReady") + } +} + +func TestBaseRouter_FastPath(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest("a")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.serveCalls.Load(); got != 1 { + t.Errorf("serveCalls=%d want 1", got) + } + if got := a.runCalls.Load(); got != 0 { + t.Errorf("runCalls=%d want 0 (fast path should not start)", got) + } +} + +func TestBaseRouter_OnDemandStart(t *testing.T) { + a := newFakeProcess("a") + a.autoReady = true + + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest("a")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.runCalls.Load(); got != 1 { + t.Errorf("runCalls=%d want 1", got) + } + if got := a.serveCalls.Load(); got != 1 { + t.Errorf("serveCalls=%d want 1", got) + } +} + +func TestBaseRouter_ConcurrentSameModel(t *testing.T) { + a := newFakeProcess("a") + // autoReady=false so the swap parks on WaitReady until we release it. 
+ + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + const N = 5 + var wg sync.WaitGroup + codes := make([]int, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest("a")) + codes[i] = w.Code + }(i) + } + + waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run() + <-a.runStarted // swap goroutine reached Run() + a.markReady() + wg.Wait() + + for i, c := range codes { + if c != http.StatusOK { + t.Errorf("request %d: status=%d", i, c) + } + } + if got := a.runCalls.Load(); got != 1 { + t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got) + } + if got := a.serveCalls.Load(); got != N { + t.Errorf("serveCalls=%d want %d", got, N) + } +} + +func TestBaseRouter_ContextCancel(t *testing.T) { + a := newFakeProcess("a") + // autoReady=false so swap parks forever until we mark ready. + + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + ctx, cancel := context.WithCancel(context.Background()) + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequestCtx(ctx, "a")) + close(done1) + }() + + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("a")) + close(done2) + }() + + waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap + <-a.runStarted + + cancel() + select { + case <-done1: + case <-time.After(time.Second): + t.Fatal("cancelled ServeHTTP did not return after ctx cancel") + } + + a.markReady() + select { + case <-done2: + case <-time.After(time.Second): + t.Fatal("non-cancelled ServeHTTP did not complete after swap") + } + if w2.Code != http.StatusOK { + t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String()) + } +} + +func TestBaseRouter_QueuedDifferentModel(t *testing.T) { + a := newFakeProcess("a") + pa := newFakeProcess("b") + + // Loading b must stop a. 
+ planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner) + + // First request starts a swap to A; A's autoReady=false so it parks. + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, b.testProcessed, 1) + <-a.runStarted + + // Second request for B should queue while A's swap is in flight. + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, b.testProcessed, 1) + + if got := pa.runCalls.Load(); got != 0 { + t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got) + } + + // Release A's swap. B's swap should then run. + a.markReady() + waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off + <-pa.runStarted + + select { + case <-done1: + case <-time.After(time.Second): + t.Fatal("A request did not complete") + } + pa.markReady() + select { + case <-done2: + case <-time.After(time.Second): + t.Fatal("queued B request did not complete after A's swap") + } + if w2.Code != http.StatusOK { + t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String()) + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got) + } +} + +// TestBaseRouter_QueueCollation verifies that incoming requests of the form +// a, b, c, a, b, c collapse into three swaps (one per model) and that the +// second request for each model rides the fast path — either joining the +// active swap, or being pulled out of the queue when handleSwapDone promotes +// the next model. +func TestBaseRouter_QueueCollation(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + pc := newFakeProcess("c") + + // Each model evicts the other two so all swaps are mutually exclusive. 
+ planner := &stubPlanner{evict: map[string][]string{ + "a": {"b", "c"}, + "b": {"a", "c"}, + "c": {"a", "b"}, + }} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner) + + var ( + completedMu sync.Mutex + completed []string + ) + record := func(id string) { + completedMu.Lock() + defer completedMu.Unlock() + completed = append(completed, id) + } + + ids := []string{"a", "b", "c", "a", "b", "c"} + var wg sync.WaitGroup + for _, id := range ids { + id := id + wg.Add(1) + go func() { + defer wg.Done() + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest(id)) + if w.Code != http.StatusOK { + t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String()) + return + } + record(id) + }() + // Wait for run() to absorb this request before launching the next, + // so handlerCh receives them in launch order. + waitProcessed(t, b.testProcessed, 1) + } + + // All 6 are now parked in run()'s waiters/queue. Release each swap in + // sequence, waiting deterministically for each promotion to fire. + <-a.runStarted + a.markReady() + waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off + + <-pb.runStarted + pb.markReady() + waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off + + <-pc.runStarted + pc.markReady() + wg.Wait() + + if got := len(completed); got != 6 { + t.Fatalf("completed=%v want 6", completed) + } + + // run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2) + // but waiter goroutines may be scheduled in any order after their respond + // channel fires, so completion order isn't deterministic. Per-model counts + // (combined with the runCalls checks below) are sufficient to prove queue + // collation collapsed each pair into a single swap. 
+ aDone, bDone, cDone := 0, 0, 0 + for _, id := range completed { + switch id { + case "a": + aDone++ + case "b": + bDone++ + case "c": + cDone++ + } + } + if aDone != 2 || bDone != 2 || cDone != 2 { + t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed) + } + + // Single swap per model — the second request for each must have ridden + // the fast path (joined active swap or joined a queued sibling), not + // triggered an extra Run. + if got := a.runCalls.Load(); got != 1 { + t.Errorf("a.runCalls=%d want 1", got) + } + if got := pb.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } + if got := pc.runCalls.Load(); got != 1 { + t.Errorf("c.runCalls=%d want 1", got) + } +} + +// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with +// non-conflicting evict sets are loaded in parallel: both Run() calls happen +// before either process is marked ready. +func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + + // Empty evict sets for both: they can load in parallel. + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{}) + + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, b.testProcessed, 1) + + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, b.testProcessed, 1) + + // Both swaps must have reached Run() before either is marked ready — + // proves they ran in parallel rather than serializing. 
+ <-a.runStarted + <-pb.runStarted + + a.markReady() + pb.markReady() + + select { + case <-done1: + case <-time.After(time.Second): + t.Fatal("request A did not complete") + } + select { + case <-done2: + case <-time.After(time.Second): + t.Fatal("request B did not complete") + } + + if w1.Code != http.StatusOK { + t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String()) + } + if w2.Code != http.StatusOK { + t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String()) + } + if got := a.stopCalls.Load(); got != 0 { + t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got) + } + if got := pb.stopCalls.Load(); got != 0 { + t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got) + } +} + +// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap +// unblocks every queued request that no longer collides — they all start in +// parallel rather than one-per-completion. +func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + pc := newFakeProcess("c") + + // A's swap evicts both B and C, so B and C must queue. Once A finishes + // B and C themselves have empty evict sets, so they can start together. + planner := &stubPlanner{evict: map[string][]string{ + "a": {"b", "c"}, + }} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner) + + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, b.testProcessed, 1) + <-a.runStarted + + // B and C arrive while A is loading. evict_b and evict_c are empty, + // but collidesWith returns true because they appear in A's evict set. 
+ w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, b.testProcessed, 1) + + w3 := httptest.NewRecorder() + done3 := make(chan struct{}) + go func() { + b.ServeHTTP(w3, newRequest("c")) + close(done3) + }() + waitProcessed(t, b.testProcessed, 1) + + if got := pb.runCalls.Load(); got != 0 { + t.Errorf("b started early: runCalls=%d", got) + } + if got := pc.runCalls.Load(); got != 0 { + t.Errorf("c started early: runCalls=%d", got) + } + + // Release A. The swapDone handler should drain the queue and start + // both B and C in parallel. + a.markReady() + waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C + <-pb.runStarted + <-pc.runStarted + + pb.markReady() + pc.markReady() + + for i, ch := range []chan struct{}{done1, done2, done3} { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("request %d did not complete", i) + } + } +} + +// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns +// the shutdown error to every waiter on every active swap AND to every +// queued request. +func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + pc := newFakeProcess("c") + + // a and b load in parallel (empty evicts). c collides with both. + planner := &stubPlanner{evict: map[string][]string{ + "c": {"a", "b"}, + }} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner) + + const waitersPer = 2 + var wg sync.WaitGroup + codes := make([]int, 0, 2*waitersPer+1) + var codesMu sync.Mutex + record := func(code int) { + codesMu.Lock() + codes = append(codes, code) + codesMu.Unlock() + } + + launch := func(model string) { + wg.Add(1) + go func() { + defer wg.Done() + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest(model)) + record(w.Code) + }() + } + + // Active swaps for a and b, each with 2 waiters. 
+ for i := 0; i < waitersPer; i++ { + launch("a") + waitProcessed(t, b.testProcessed, 1) + } + for i := 0; i < waitersPer; i++ { + launch("b") + waitProcessed(t, b.testProcessed, 1) + } + // c collides with both → queues. + launch("c") + waitProcessed(t, b.testProcessed, 1) + + <-a.runStarted + <-pb.runStarted + + if err := b.Shutdown(time.Second); err != nil { + t.Fatalf("Shutdown: %v", err) + } + wg.Wait() + + codesMu.Lock() + defer codesMu.Unlock() + if len(codes) != 2*waitersPer+1 { + t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1) + } + for i, c := range codes { + if c == http.StatusOK { + t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c) + } + } +} + +// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model +// is not stopped to satisfy another model's swap while it is still handling +// a request. +// +// Sequence: +// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock. +// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live. +// 3. r3 (A) — arrives next; the existing code queues it because B's swap +// intent collides with A. +// 4. r1 released — A finishes r1, then r3 is served by A. +// 5. B's swap then proceeds; r2 is served by B. +// +// fakeProcess.stoppedWhileServing flips true if Stop is ever called while +// a ServeHTTP is in flight — a direct, race-free signal of the violation. +func TestBaseRouter_NoSwapWhileServing(t *testing.T) { + a := newFakeProcess("a") + // autoReady left false: we markReady manually after observing runStarted, + // so autoReady's setState(Ready) cannot race with a later Stop and leave + // A in Ready, masking the bug. + a.serveBlock = make(chan struct{}) + pb := newFakeProcess("b") + // Same reasoning for B: park its swap on WaitReady until we choose. 
+ + planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}} + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner) + + // r1 — load A and enter its ServeHTTP (which blocks on serveBlock). + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + b.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, b.testProcessed, 1) // handlerReq for r1 + <-a.runStarted + a.markReady() + waitProcessed(t, b.testProcessed, 1) // swapDone for A + <-a.serveStarted + + // r2 — would evict A. A must not be stopped while r1 is in flight. + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + b.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, b.testProcessed, 1) + + // r3 — another request for A, arrives behind r2 and queues because + // B's swap intent (which evicts A) is recorded as active. + w3 := httptest.NewRecorder() + done3 := make(chan struct{}) + go func() { + b.ServeHTTP(w3, newRequest("a")) + close(done3) + }() + waitProcessed(t, b.testProcessed, 1) + + // Release r1 (and r3 if it is fast-pathed onto the still-loaded A). + // The router must hold off B's swap until A has drained. + close(a.serveBlock) + + select { + case <-done1: + case <-time.After(2 * time.Second): + t.Fatal("r1 did not complete after serveBlock release") + } + + // Wait for B.Run before marking it ready: markReady before Run would + // skip the Run path entirely and leave pb.runCalls at 0. In a correct + // implementation B's swap only starts after A has drained; in the + // current implementation it has already started — either way runStarted + // fires. 
+ <-pb.runStarted + pb.markReady() + + select { + case <-done2: + case <-time.After(2 * time.Second): + t.Fatal("r2 did not complete after B marked ready") + } + select { + case <-done3: + case <-time.After(2 * time.Second): + t.Fatal("r3 did not complete") + } + + if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK { + t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code) + } + if w1.Body.String() != "ok:a" { + t.Errorf("r1 body=%q want ok:a", w1.Body.String()) + } + if w3.Body.String() != "ok:a" { + t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String()) + } + if w2.Body.String() != "ok:b" { + t.Errorf("r2 body=%q want ok:b", w2.Body.String()) + } + if a.stoppedWhileServing.Load() { + t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process") + } +} + +func TestBaseRouter_ModelNotFound(t *testing.T) { + a := newFakeProcess("a") + b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{}) + + w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest("unknown")) + + if w.Code != http.StatusNotFound { + t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String()) + } +} + +func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) + pb := newFakeProcess("b") + pb.markReady() + go pb.Run(0) + + b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{}) + + if err := b.Shutdown(time.Second); err != nil { + t.Fatalf("Shutdown: %v", err) + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1", got) + } + if got := pb.stopCalls.Load(); got != 1 { + t.Errorf("b.stopCalls=%d want 1", got) + } + + // Subsequent ServeHTTP should report 5xx. 
+ w := httptest.NewRecorder() + b.ServeHTTP(w, newRequest("a")) + if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable { + t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String()) + } + + // Second Shutdown should report already in progress. + if err := b.Shutdown(0); err == nil { + t.Errorf("second Shutdown returned nil, want error") + } +} diff --git a/internal/router/group.go b/internal/router/group.go new file mode 100644 index 0000000..db36caa --- /dev/null +++ b/internal/router/group.go @@ -0,0 +1,110 @@ +package router + +import ( + "fmt" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +type Group struct { + *baseRouter +} + +func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) { + modelToGroup := make(map[string]string) + for gid, gcfg := range conf.Groups { + for _, mid := range gcfg.Members { + if existing, dup := modelToGroup[mid]; dup { + return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid) + } + modelToGroup[mid] = gid + } + } + + planner := &groupPlanner{ + config: conf, + modelToGroup: modelToGroup, + } + + processes := make(map[string]process.Process, len(modelToGroup)) + base := newBaseRouter("group", conf, processes, planner, proxylog) + planner.processes = processes + + for mid := range modelToGroup { + modelCfg, _, ok := conf.FindConfig(mid) + if !ok { + base.shutdownFn() + return nil, fmt.Errorf("no model config for %q", mid) + } + procLog := logmon.NewWriter(upstreamlog) + p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog) + if err != nil { + base.shutdownFn() + return nil, fmt.Errorf("creating process for %q: %w", mid, err) + } + processes[mid] = p + } + + g := &Group{baseRouter: base} + go base.run() + return g, nil +} + +// groupPlanner decides evictions from static group 
configuration. +// +// Same-group siblings are stopped when the group has swap=true. Cross-group +// members are stopped only when the target's group is exclusive; loading a +// model from a non-exclusive group leaves running exclusive groups alone, +// matching the gotcha in the original ProcessGroup behaviour. +type groupPlanner struct { + config config.Config + modelToGroup map[string]string + processes map[string]process.Process +} + +func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string { + tg := p.modelToGroup[target] + tgCfg := p.config.Groups[tg] + + seen := make(map[string]struct{}) + var result []string + consider := func(mID string) { + if mID == target { + return + } + if _, dup := seen[mID]; dup { + return + } + og := p.modelToGroup[mID] + switch { + case og == tg && tgCfg.Swap: + seen[mID] = struct{}{} + result = append(result, mID) + // the previous ProcessGroup behaviour did not unload exclusive groups + // when loading a non-exclusive model. This maintains that gotcha + // for backwards compatibility. The newer swap matrix approach does not + // have this issue. 
+ case og != tg && tgCfg.Exclusive: + if ogCfg := p.config.Groups[og]; !ogCfg.Persistent { + seen[mID] = struct{}{} + result = append(result, mID) + } + } + } + + for mID, proc := range p.processes { + st := proc.State() + if st == process.StateStopped || st == process.StateShutdown { + continue + } + consider(mID) + } + for _, mID := range alsoRunning { + consider(mID) + } + return result +} + +func (p *groupPlanner) OnSwapStart(target string) {} diff --git a/internal/router/group_test.go b/internal/router/group_test.go new file mode 100644 index 0000000..5ba2c42 --- /dev/null +++ b/internal/router/group_test.go @@ -0,0 +1,331 @@ +package router + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// newTestGroup builds a Group directly from the supplied processes and config, +// bypassing NewGroup's call to process.New. 
+func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group { + t.Helper() + modelToGroup := make(map[string]string) + for gid, gcfg := range conf.Groups { + for _, mid := range gcfg.Members { + modelToGroup[mid] = gid + } + } + planner := &groupPlanner{ + config: conf, + modelToGroup: modelToGroup, + processes: processes, + } + base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard)) + base.testProcessed = make(chan struct{}, 64) + g := &Group{baseRouter: base} + go base.run() + t.Cleanup(func() { + if !g.shuttingDown.Load() { + _ = g.Shutdown(time.Second) + } + }) + return g +} + +func TestGroup_NewGroup_DuplicateMembership(t *testing.T) { + conf := config.Config{ + Groups: map[string]config.GroupConfig{ + "g1": {Swap: true, Members: []string{"a"}}, + "g2": {Swap: true, Members: []string{"a"}}, + }, + Models: map[string]config.ModelConfig{ + "a": {}, + }, + } + log := logmon.NewWriter(io.Discard) + if _, err := NewGroup(conf, log, log); err == nil { + t.Fatalf("expected error for duplicate membership") + } +} + +func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) // park a Run goroutine so Stop has something to release + + b := newFakeProcess("b") + b.autoReady = true + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + g.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1", got) + } + if got := b.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } + if got := b.serveCalls.Load(); got != 1 { + t.Errorf("b.serveCalls=%d want 1", got) + } +} + +func 
TestGroup_NonSwapGroup_NoStop(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + + b := newFakeProcess("b") + b.autoReady = true + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + g.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 0 { + t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got) + } + if got := b.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } +} + +func TestGroup_CrossGroupExclusive(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) + + b := newFakeProcess("b") + b.autoReady = true + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g1": {Swap: true, Exclusive: true, Members: []string{"a"}}, + "g2": {Swap: true, Exclusive: true, Members: []string{"b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + g.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got) + } +} + +// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for +// models in distinct non-exclusive groups load in parallel rather than +// serializing through the router's run loop. 
+func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g1": {Swap: true, Exclusive: false, Members: []string{"a"}}, + "g2": {Swap: true, Exclusive: false, Members: []string{"b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb}) + + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + g.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, g.testProcessed, 1) + + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + g.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, g.testProcessed, 1) + + // Both groups load concurrently — both must reach Run() before either is + // marked ready. If the router still serialised, only one would proceed. + <-a.runStarted + <-pb.runStarted + + a.markReady() + pb.markReady() + + for i, ch := range []chan struct{}{done1, done2} { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("request %d did not complete", i) + } + } + if got := a.stopCalls.Load(); got != 0 { + t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got) + } + if got := pb.stopCalls.Load(); got != 0 { + t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got) + } +} + +// TestGroup_SameGroupSwapSerialises verifies that two same-group requests +// (Swap=true) serialise even when both arrive while neither has reached +// StateStarting yet — the alsoRunning hint to the planner closes that race. 
+func TestGroup_SameGroupSwapSerialises(t *testing.T) { + a := newFakeProcess("a") + pb := newFakeProcess("b") + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb}) + + w1 := httptest.NewRecorder() + done1 := make(chan struct{}) + go func() { + g.ServeHTTP(w1, newRequest("a")) + close(done1) + }() + waitProcessed(t, g.testProcessed, 1) + + // Request B arrives before A transitions to StateStarting in the process + // state machine. Without the alsoRunning hint, the planner would not see + // A as running, and B would start in parallel, violating Swap=true. + w2 := httptest.NewRecorder() + done2 := make(chan struct{}) + go func() { + g.ServeHTTP(w2, newRequest("b")) + close(done2) + }() + waitProcessed(t, g.testProcessed, 1) + + if got := pb.runCalls.Load(); got != 0 { + t.Errorf("b started in parallel: runCalls=%d want 0", got) + } + + <-a.runStarted + a.markReady() + waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted + <-pb.runStarted + pb.markReady() + + for i, ch := range []chan struct{}{done1, done2} { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("request %d did not complete", i) + } + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got) + } +} + +// TestGroup_PersistentNotEvicted verifies that a group with persistent=true +// is never evicted when another exclusive group starts loading. The running +// model in the persistent group stays alive alongside the new one. 
+func TestGroup_PersistentNotEvicted(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) + + b := newFakeProcess("b") + b.autoReady = true + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}}, + "other": {Swap: true, Exclusive: true, Members: []string{"b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + g.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 0 { + t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got) + } + if a.State() != process.StateStarting && a.State() != process.StateReady { + t.Errorf("a state=%s want still running", a.State()) + } + if got := b.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } +} + +// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible +// gotcha from the original ProcessGroup: when a model in a non-exclusive group +// is loaded, any running exclusive group keeps running. The two coexist. 
+func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) + + b := newFakeProcess("b") + b.autoReady = true + + conf := config.Config{ + HealthCheckTimeout: 5, + Groups: map[string]config.GroupConfig{ + "g1": {Swap: true, Exclusive: true, Members: []string{"a"}}, + "g2": {Swap: true, Exclusive: false, Members: []string{"b"}}, + }, + } + g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + g.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 0 { + t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got) + } + if a.State() != process.StateStarting && a.State() != process.StateReady { + t.Errorf("a state=%s want still running", a.State()) + } + if got := b.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } +} diff --git a/internal/router/helpers_test.go b/internal/router/helpers_test.go new file mode 100644 index 0000000..ce1df6b --- /dev/null +++ b/internal/router/helpers_test.go @@ -0,0 +1,205 @@ +package router + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// fakeProcess is an in-memory implementation of process.Process used to drive +// the routers through their state machine without spawning real upstreams. +type fakeProcess struct { + id string + + mu sync.Mutex + state process.ProcessState + readyCh chan struct{} + stopCh chan struct{} + runStarted chan struct{} // closed on the first Run call + stopStarted chan struct{} // closed on the first Stop call + + autoReady bool + + // serveBlock, when non-nil, makes ServeHTTP receive from it before + // writing its response. 
Tests use this to hold a request in-flight. + // Closing the channel releases every blocked ServeHTTP caller. + serveBlock chan struct{} + // serveStarted is closed on the first ServeHTTP entry, letting tests + // wait deterministically for the handler to begin executing. + serveStarted chan struct{} + // stopBlock, when non-nil, makes Stop receive from it (after signalling + // stopStarted) before completing. Tests use this to prove that several + // Stop calls can be in flight simultaneously. + stopBlock chan struct{} + + runCalls atomic.Int32 + stopCalls atomic.Int32 + serveCalls atomic.Int32 + + // inFlightServe counts ServeHTTP calls currently inside the handler. + // stoppedWhileServing flips true if Stop is ever called while that + // counter is non-zero — a direct, race-free observation of the + // "swap mid-request" anti-property. + inFlightServe atomic.Int32 + stoppedWhileServing atomic.Bool +} + +func newFakeProcess(id string) *fakeProcess { + return &fakeProcess{ + id: id, + state: process.StateStopped, + readyCh: make(chan struct{}), + stopCh: make(chan struct{}), + runStarted: make(chan struct{}), + stopStarted: make(chan struct{}), + serveStarted: make(chan struct{}), + } +} + +func (f *fakeProcess) setState(s process.ProcessState) { + f.mu.Lock() + defer f.mu.Unlock() + f.state = s + if s == process.StateReady { + select { + case <-f.readyCh: + default: + close(f.readyCh) + } + } +} + +func (f *fakeProcess) State() process.ProcessState { + f.mu.Lock() + defer f.mu.Unlock() + return f.state +} + +func (f *fakeProcess) markReady() { f.setState(process.StateReady) } + +func (f *fakeProcess) Run(_ time.Duration) error { + f.runCalls.Add(1) + f.mu.Lock() + if f.state != process.StateStopped { + s := f.state + f.mu.Unlock() + return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s) + } + f.state = process.StateStarting + sc := f.stopCh + select { + case <-f.runStarted: + default: + close(f.runStarted) + } + f.mu.Unlock() + + if f.autoReady { + 
f.setState(process.StateReady) + } + <-sc + return nil +} + +func (f *fakeProcess) Stop(_ time.Duration) error { + f.stopCalls.Add(1) + if f.inFlightServe.Load() > 0 { + f.stoppedWhileServing.Store(true) + } + f.mu.Lock() + select { + case <-f.stopStarted: + default: + close(f.stopStarted) + } + f.mu.Unlock() + + // Test hook: hold Stop here so the test can prove multiple Stops are + // in flight at the same time before any of them complete. + if f.stopBlock != nil { + <-f.stopBlock + } + + f.mu.Lock() + defer f.mu.Unlock() + if f.state == process.StateStopped { + return nil + } + f.state = process.StateStopped + select { + case <-f.stopCh: + default: + close(f.stopCh) + } + return nil +} + +func (f *fakeProcess) WaitReady(ctx context.Context) error { + f.mu.Lock() + if f.state == process.StateReady { + f.mu.Unlock() + return nil + } + rc := f.readyCh + f.mu.Unlock() + select { + case <-rc: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) } + +func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + f.serveCalls.Add(1) + f.inFlightServe.Add(1) + defer f.inFlightServe.Add(-1) + f.mu.Lock() + select { + case <-f.serveStarted: + default: + close(f.serveStarted) + } + f.mu.Unlock() + if f.serveBlock != nil { + <-f.serveBlock + } + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "ok:%s", f.id) +} + +// waitProcessed drains n events from ch, fataling on timeout. One event fires +// per handlerReq or swapDone fully absorbed by run(). 
+func waitProcessed(t *testing.T, ch chan struct{}, n int) { + t.Helper() + for i := 0; i < n; i++ { + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatalf("waitProcessed: only %d/%d events received", i, n) + } + } +} + +func newRequest(model string) *http.Request { + body := fmt.Sprintf(`{"model":%q}`, model) + r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/json") + return r +} + +func newRequestCtx(ctx context.Context, model string) *http.Request { + return newRequest(model).WithContext(ctx) +} diff --git a/internal/router/loading.go b/internal/router/loading.go new file mode 100644 index 0000000..99c6ee8 --- /dev/null +++ b/internal/router/loading.go @@ -0,0 +1,249 @@ +package router + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "strings" + "sync" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +var loadingPaths = []string{ + "/v1/chat/completions", +} + +func isLoadingPath(path string) bool { + for _, p := range loadingPaths { + if strings.HasPrefix(path, p) { + return true + } + } + return false +} + +type loadingWriter struct { + hasWritten bool + writer http.ResponseWriter + req *http.Request + ctx context.Context + logger *logmon.Monitor + modelName string + startTime time.Time + + pendingMu sync.Mutex + pendingUpdate string + + // closed by start when the goroutine finishes (after cleanup messages) + done chan struct{} + + // test-only: closed when start enters its loop + loopStarted chan struct{} + // test-only: override the 1s tick interval + tickDuration time.Duration + // test-only: override character streaming speed (0 = no delay) + charPerSecond float64 +} + +func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter { + s := &loadingWriter{ + writer: w, + req: req, + ctx: req.Context(), + logger: logger, + modelName: modelName, + 
startTime: time.Now(), + tickDuration: 750 * time.Millisecond, + charPerSecond: 75, + } + + s.Header().Set("Content-Type", "text/event-stream") + s.Header().Set("Cache-Control", "no-cache") + s.Header().Set("Connection", "keep-alive") + s.WriteHeader(http.StatusOK) + s.sendLine("━━━━━") + s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName)) + return s +} + +func (s *loadingWriter) setUpdate(msg string) { + s.pendingMu.Lock() + s.pendingUpdate = msg + s.pendingMu.Unlock() +} + +func (s *loadingWriter) start(ctx context.Context) { + s.done = make(chan struct{}) + defer close(s.done) + + defer func() { + // Skip cleanup writes if the client disconnected — the connection + // is being torn down and flushing against it will panic. + if s.ctx.Err() != nil { + return + } + duration := time.Since(s.startTime) + s.sendData("\n") + s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds())) + s.sendLine("━━━━━") + s.sendLine(" ") + }() + + remarks := make([]string, len(loadingRemarks)) + copy(remarks, loadingRemarks) + rand.Shuffle(len(remarks), func(i, j int) { + remarks[i], remarks[j] = remarks[j], remarks[i] + }) + ri := 0 + + nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second + lastRemarkTime := time.Time{} + + ticker := time.NewTicker(s.tickDuration) + defer ticker.Stop() + + if s.loopStarted != nil { + close(s.loopStarted) + } + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.pendingMu.Lock() + update := s.pendingUpdate + s.pendingUpdate = "" + s.pendingMu.Unlock() + + if update != "" { + s.sendData("\n") + s.sendInline(update) + s.sendData(" ") + lastRemarkTime = time.Now() + nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second + } else if time.Since(lastRemarkTime) >= nextRemarkIn { + remark := remarks[ri%len(remarks)] + ri++ + s.sendData("\n") + s.sendInline(remark) + s.sendData(" ") + lastRemarkTime = time.Now() + nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second + } else { + s.sendData(".") + } + } + } 
+} + +func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool { + if s.done == nil { + return true + } + select { + case <-s.done: + return true + case <-time.After(timeout): + return false + } +} + +func (s *loadingWriter) sendInline(text string) { + chunkSize := 10 + if s.charPerSecond > 0 { + chunkSize = max(3, int(s.charPerSecond)/15) + } + + runes := []rune(text) + for i := 0; i < len(runes); { + select { + case <-s.ctx.Done(): + return + default: + } + + end := i + chunkSize + if end > len(runes) { + end = len(runes) + } + chunk := string(runes[i:end]) + s.sendData(chunk) + i = end + + if i < len(runes) && s.charPerSecond > 0 { + time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond)) + } + } +} + +func (s *loadingWriter) sendLine(line string) { + if line == "" { + s.sendData("\n") + return + } + s.sendInline(line) + s.sendData("\n") +} + +func (s *loadingWriter) sendData(data string) { + type Delta struct { + ReasoningContent string `json:"reasoning_content"` + } + type Choice struct { + Delta Delta `json:"delta"` + } + type SSEMessage struct { + Choices []Choice `json:"choices"` + } + + msg := SSEMessage{ + Choices: []Choice{ + { + Delta: Delta{ + ReasoningContent: data, + }, + }, + }, + } + + jsonData, err := json.Marshal(msg) + if err != nil { + s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err) + return + } + + _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData) + if err != nil { + s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err) + return + } + s.Flush() +} + +func (s *loadingWriter) Header() http.Header { + return s.writer.Header() +} + +func (s *loadingWriter) Write(data []byte) (int, error) { + return s.writer.Write(data) +} + +func (s *loadingWriter) WriteHeader(statusCode int) { + if s.hasWritten { + return + } + s.hasWritten = true + s.writer.WriteHeader(statusCode) + s.Flush() +} + +func (s *loadingWriter) Flush() { + if 
flusher, ok := s.writer.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/router/loading_remarks.go b/internal/router/loading_remarks.go new file mode 100644 index 0000000..9fed748 --- /dev/null +++ b/internal/router/loading_remarks.go @@ -0,0 +1,133 @@ +package router + +var loadingRemarks = []string{ + "Still faster than your last standup meeting", + "Reticulating splines", + "Waking up the hamsters", + "Teaching the model manners", + "Convincing the GPU to participate", + "Loading weights (they're heavy)", + "Please enjoy this elevator music in your head", + "Pretending to be productive", + "Reading the entire internet, page by page", + "Staring at the abyss, the abyss is buffering", + "Applying layer after layer of disembodied cognition", + "Remembering everything it forgot during quantization", + "Counting to 405 billion, one parameter at a time", + "Summoning the stochastic parroting", + "Hold on, the GPU is questioning its existence", + "Deciding which facts to hallucinate today", + "Untangling the transformer spaghetti", + "Warming up the token soup", + "Your prompt is in a queue, behind 7 billion other thoughts", + "Running `sudo apt-get install intelligence`", + "Defragmenting the latent space", + "Polishing each matrix multiplication by hand", + "Whispering sweet nothings to the attention heads", + "Aligning with human values, one reluctant epoch at a time", + "The model is thinking about what it's about to think about", + "Loading... and by loading we mean making you wait", + "Spinning up the cloud GPU, please be patient while we burn your credits", + "Applying duct tape to the context window", + "Bribing the GPU scheduler for a timeslice", + "Would you like to hear a fun fact while we load? 
Too bad.", + "Hot swapping your sanity for an LLM", + "Compressing optimism into FP16", + "Ignoring 90% of the attention to save you 50% of the time", + "Counting the exact same thing three times just to be sure", + "Sorry, the inference you have reached is not in service", + "Rotating the positional encodings counterclockwise for good luck", + "Your call is very important to us. Please continue to hold.", + "Unpacking the blobs. All 300GB of them.", + "Initializing the thing that initializes the other thing", + "Converting electricity into existential dread", + "Flattening the curve... wait, the tensor. Flattening the tensor.", + "Fetching the fetch of a fetch, callback hell edition", + "The GPU is at 100%. The fan is now a helicopter.", + "Baking the weights at 350° for a golden-brown inference", + "Recalibrating the confidence of things it's still wrong about", + "Have you tried turning it off and on again? No? Good, wait here.", + "Simulating deep thought by pausing dramatically", + "Loading the model that knows more than you but still can't count r's in 'strawberry'", + "Convincing CUDA to cooperate. This may take a while.", + "VRAM: 23.9GB used of 24GB. Living on the edge.", + "Processing your request with the urgency of a DMV employee", + "This model was trained on the entire internet, including that embarrassing blog you wrote in 2008", + "Dispatching tokens through a series of increasingly confused matrix multiplies", + "Gently lowering your expectations", + "Applying softmax to our feelings about this load time", + "Autoregressively generating disappointment, one token at a time", + "The magic is happening. Somewhere. Probably.", + "Synchronizing the parallel processes that run in parallel but really don't", + "Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.", + "Loading... just like it said 30 seconds ago. 
And will say 30 seconds from now.", + "Pre-warming the cache so the first query is only slightly slower than the rest", + "Have you considered that maybe your question wasn't worth all this compute?", + "Downloading more RAM (no, really, we're mmap-ing the weights)", + "Translating your prompt into math it barely understands", + "Estimating your time remaining with 0% accuracy", + "Buffering enthusiasm", + "Model is loading. Go make some coffee. Or a three-course meal.", + "Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'", + "Polling for readiness in a loop that would make your CS professor weep", + "Performing percussive maintenance on the attention mechanism", + "This loading screen is singlehandedly reversing climate progress", + "Decompressing the hopes and dreams of thousands of underpaid labelers", + "Filling the key-value cache with the ghost of prompts past", + "Currently at step 3 of 9,742 of loading. We'll get there. Eventually.", + "If you stare at the spinner, it spins slower. It's science.", + "Multiplying matricies with the enthusiasm of a teenager doing chores", + "Applying `torch.nap()` until the model feels refreshed", + "Reacquainting the model with the concept of 'facts' it forgot during fine-tuning", + "Sorry for the wait. 
No, wait, we're not actually sorry.", + "Your GPU is now a space heater with a side hustle in linear algebra", + "Allocating memory like a billionaire allocates tax avoidance strategies", + "The model saw \"As an AI language model\" and won't stop saying it now", + "Installing dependencies you didn't know existed and will never use again", + "Re-reading 'Attention Is All You Need' for the 400th time", + "Convincing the embedding layer that context is overrated", + "Manually untangling the residual connections with a tiny comb", + "On hold with the cloud provider trying to explain why 8 H100s isn't enough", + "Adjusting temperatures: model is 0.7, server room is 104°F", + "Please hold while we justify this electricity bill to accounting", + "Stacking decoder blocks like a Jenga tower at a LAN party", + "Compensating for your lack of patience with our lack of speed", + "This is a loading screen comment. Loading screens have comments now. Welcome to the future.", + "Processing the entire works of Shakespeare backwards just in case", + "The model is loading slower than your last `npm install`", + "Rehearsing plausible-sounding explanations for why it got everything wrong", + "Populating the context with filler while you wait for actual content", + "Optimizing for BLEU score, which definitely correlates with making you laugh", + "Generating an embedding for each and every letter of the alphabet, individually", + "Coming soon: llama-swap v2 with actual performance improvements. Probably.", + "Loading a model larger than your attention span", + "Performing a seance to invoke the spirit of Geoff Hinton", + "Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.", + "Converting all the internet's bad opinions into a surprisingly useful autocomplete", + "Laying down each layer with the care of a Michelin-starred pastry chef", + "Checking if the model still thinks birds are government drones. 
Yep.", + "Activating the neurons responsible for 'I cannot assist with that request'", + "This model was trained on the same internet that brought you Rickrolling. You're welcome.", + "Realigning the alignment so it aligns with the previous alignment", + "Running `nvidia-smi` and sighing heavily", + "If you close your eyes, the loading bar moves faster. Proven by science.", + "EULA said 'by using this software you agree to wait forever' and you clicked Accept", + "Zipping the GPUs to make them go faster", + "Padding the context window with existential padding", + "We could have used a smaller model but someone wanted 'quality'", + "Disentangling the latent space into something resembling coherence", + "Slow is smooth, smooth is fast, but this is just slow", + "Memory-mapping like it's a AAA title from 2012", + "Your patience has been tokenized and added to the training set. Thank you for your contribution.", + "Loading is CPU-bound and your CPU is busy regretting its life choices", + "Exploring the high-dimensional manifold of ways to say 'just a moment'", + "The model is experiencing a brief but intense moment of imposter syndrome", + "Initializing 7B parameters by rolling 7B 16-sided dice", + "Panic! at the disk I/O", + "Intelligence is loading... your definition of intelligence may vary", + "This model was distilled. Unlike your patience, which is evaporating.", + "Unzipping the model. It's a .gguf file, not a metaphor.", + "Running inference on the concept of 'soon' to estimate remaining time", + "Loading with all the speed of a government-funded IT project", + "A blank terminal is a terrible thing to waste. 
Here's a loading message instead.", +} diff --git a/internal/router/loading_test.go b/internal/router/loading_test.go new file mode 100644 index 0000000..57dc3bf --- /dev/null +++ b/internal/router/loading_test.go @@ -0,0 +1,328 @@ +package router + +import ( + "bufio" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + + if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" { + t.Errorf("Content-Type: want text/event-stream, got %q", ct) + } + if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" { + t.Errorf("Cache-Control: want no-cache, got %q", cc) + } + if conn := lw.Header().Get("Connection"); conn != "keep-alive" { + t.Errorf("Connection: want keep-alive, got %q", conn) + } + + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Errorf("expected SSE data: prefix, got: %s", body) + } + + content := extractStreamedContent(body) + if !strings.Contains(content, "━━━━━\n") { + t.Errorf("missing separator in streamed content: %q", content) + } + if !strings.Contains(content, "llama-swap loading model: test-model\n") { + t.Errorf("missing initial message in streamed content: %q", content) + } +} + +func TestLoadingWriter_WriteHeaderOnce(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.WriteHeader(http.StatusCreated) + + if w.Code != http.StatusOK { + t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code) + } +} + +func TestLoadingWriter_WritePassthrough(t 
*testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.Write([]byte("hello")) + lw.Flush() + + body := w.Body.String() + if !strings.Contains(body, "hello") { + t.Errorf("Write passthrough failed, body: %s", body) + } +} + +func TestLoadingWriter_StartStopsOnCancel(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.tickDuration = 10 * time.Millisecond + lw.loopStarted = make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + + go lw.start(ctx) + <-lw.loopStarted + cancel() + + if !lw.waitForCompletion(time.Second) { + t.Fatal("waitForCompletion timed out") + } + + body := w.Body.String() + if !strings.Contains(body, "Done!") { + t.Errorf("expected Done! 
message, body: %s", body) + } +} + +func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.tickDuration = 10 * time.Millisecond + lw.charPerSecond = 0 + lw.loopStarted = make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + go lw.start(ctx) + <-lw.loopStarted + + lw.setUpdate("custom status message") + time.Sleep(50 * time.Millisecond) + cancel() + + if !lw.waitForCompletion(time.Second) { + t.Fatal("waitForCompletion timed out") + } + + body := w.Body.String() + content := extractStreamedContent(body) + if !strings.Contains(content, "custom status message") { + t.Errorf("expected setUpdate message in output, got: %q", content) + } +} + +func TestLoadingWriter_SendDataFormat(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.sendData("hello world") + + body := w.Body.String() + if !strings.Contains(body, `"reasoning_content":"hello world"`) { + t.Errorf("expected reasoning_content in SSE data, body: %s", body) + } + if !strings.HasPrefix(body, "data: ") { + t.Errorf("expected data: prefix, got: %s", body) + } +} + +func TestLoadingWriter_SendLine(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.charPerSecond = 0 + + // Capture only the content from this sendLine call + before := w.Body.Len() + lw.sendLine("line content") + after := w.Body.Len() + chunkBody := w.Body.String()[before:after] + + content := extractStreamedContent(chunkBody) + if content != "line content\n" { + 
t.Errorf("expected complete streamed line, got: %q", content) + } +} + +func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + lw.tickDuration = 10 * time.Millisecond + lw.charPerSecond = 0 + lw.loopStarted = make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + lw.start(ctx) + close(done) + }() + + <-lw.loopStarted + time.Sleep(50 * time.Millisecond) + cancel() + <-done + + body := w.Body.String() + lines := countSSEMessages(body) + if lines < 2 { + t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines) + } +} + +func TestLoadingWriter_ReqStored(t *testing.T) { + logger := logmon.NewWriter(io.Discard) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + lw := newLoadingWriter(logger, "test-model", w, req) + if lw.req != req { + t.Fatal("req not stored") + } +} + +func TestIsLoadingPath(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"/v1/chat/completions", true}, + {"/v1/chat/completions/extra", true}, + {"/v1/completions", false}, + {"/v1/embeddings", false}, + {"/health", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + if got := isLoadingPath(tt.path); got != tt.want { + t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestExtractContext_Streaming_GET(t *testing.T) { + tests := []struct { + name string + query string + wantStreaming bool + }{ + {"streaming true", "model=llama3&stream=true", true}, + {"streaming false", "model=llama3&stream=false", false}, + {"no stream param", "model=llama3", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
// countSSEMessages returns the number of lines in s that begin with the SSE
// "data: " prefix.
//
// The scanner's maximum token size is derived from the input length. With the
// default limit (bufio.MaxScanTokenSize, 64 KiB) a single longer line makes
// Scan return false, which would silently drop that line and every line after
// it from the count.
func countSSEMessages(s string) int {
	scanner := bufio.NewScanner(strings.NewReader(s))
	// len(s)+1 keeps every possible line strictly below the token limit.
	scanner.Buffer(nil, len(s)+1)

	count := 0
	for scanner.Scan() {
		if strings.HasPrefix(scanner.Text(), "data: ") {
			count++
		}
	}
	return count
}

// extractStreamedContent concatenates choices[0].delta.reasoning_content from
// every "data: " line in body, in order. Lines without the prefix, lines whose
// payload is not valid JSON for that shape, and payloads with no choices are
// skipped.
//
// As in countSSEMessages, the scanner buffer is sized from the input so that a
// line longer than 64 KiB cannot truncate the result.
func extractStreamedContent(body string) string {
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Buffer(nil, len(body)+1)

	var result strings.Builder
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		jsonData := strings.TrimPrefix(line, "data: ")

		var msg struct {
			Choices []struct {
				Delta struct {
					ReasoningContent string `json:"reasoning_content"`
				} `json:"delta"`
			} `json:"choices"`
		}
		// Non-JSON payloads are expected in the stream and are ignored on purpose.
		if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
			continue
		}
		if len(msg.Choices) > 0 {
			result.WriteString(msg.Choices[0].Delta.ReasoningContent)
		}
	}
	return result.String()
}
+type matrixPlanner struct { + solver *matrixSolver + processes map[string]process.Process + logger *logmon.Monitor +} + +func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string { + return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict +} + +func (p *matrixPlanner) OnSwapStart(target string) { + running := p.runningModels() + result := p.solver.Solve(target, running) + switch { + case len(result.Evict) > 0: + p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d", + target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost) + case len(running) == 0: + p.logger.Infof("matrix: model=%s starting (no models running)", target) + default: + p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL) + } +} + +func (p *matrixPlanner) runningModels() []string { + return p.runningSet(nil) +} + +// runningSet returns the union of live processes (State != Stopped/Shutdown) +// and any extra IDs the baseRouter has already committed to loading but which +// the process state machine has not yet reflected. 
+func (p *matrixPlanner) runningSet(alsoRunning []string) []string { + seen := make(map[string]struct{}, len(p.processes)) + var running []string + for id, proc := range p.processes { + st := proc.State() + if st == process.StateStopped || st == process.StateShutdown { + continue + } + seen[id] = struct{}{} + running = append(running, id) + } + for _, id := range alsoRunning { + if _, dup := seen[id]; dup { + continue + } + seen[id] = struct{}{} + running = append(running, id) + } + sort.Strings(running) + return running +} diff --git a/internal/router/matrix_solver.go b/internal/router/matrix_solver.go new file mode 100644 index 0000000..a5d9054 --- /dev/null +++ b/internal/router/matrix_solver.go @@ -0,0 +1,132 @@ +package router + +import ( + "slices" + + "github.com/mostlygeek/llama-swap/internal/config" +) + +// matrixSolver contains pure swap-decision logic with no Process dependencies. +// It is safe for concurrent reads after construction. +type matrixSolver struct { + expandedSets []config.ExpandedSet // all valid model combinations + evictCosts map[string]int // real model name -> eviction cost (default 1) + modelToSets map[string][]int // model name -> indices into expandedSets +} + +func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver { + modelToSets := make(map[string][]int) + for i, es := range expandedSets { + for _, model := range es.Models { + modelToSets[model] = append(modelToSets[model], i) + } + } + + return &matrixSolver{ + expandedSets: expandedSets, + evictCosts: evictCosts, + modelToSets: modelToSets, + } +} + +// solveResult describes what the solver decided. 
+type solveResult struct { + Evict []string // running models that must be stopped + TargetSet []string // the chosen set of models (for informational purposes) + SetName string // name of the chosen set + DSL string // original DSL expression for the chosen set + TotalCost int // total eviction cost +} + +// Solve determines which models to evict when a model is requested. +// +// Algorithm: +// 1. If requestedModel is already running, no eviction needed. +// 2. Find all sets containing requestedModel. +// 3. If no sets found, the model runs alone; evict all running models. +// 4. For each candidate set, compute cost = sum of evict_costs for running +// models NOT in that set. +// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets). +// 6. Return models to evict and the chosen set. +func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult { + if slices.Contains(runningModels, requestedModel) { + setName, dsl := s.findMatchingSet(requestedModel, runningModels) + return solveResult{ + TargetSet: runningModels, + SetName: setName, + DSL: dsl, + } + } + + candidateIndices := s.modelToSets[requestedModel] + + // Model not in any set: runs alone, evict everything. 
+ if len(candidateIndices) == 0 { + evict := make([]string, len(runningModels)) + copy(evict, runningModels) + return solveResult{ + Evict: evict, + TargetSet: []string{requestedModel}, + } + } + + bestCost := -1 + bestIdx := -1 + + for _, idx := range candidateIndices { + setModels := s.expandedSets[idx].Models + cost := 0 + for _, running := range runningModels { + if !slices.Contains(setModels, running) { + cost += s.evictCost(running) + } + } + + if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) { + bestCost = cost + bestIdx = idx + } + } + + chosen := s.expandedSets[bestIdx] + var evict []string + for _, running := range runningModels { + if !slices.Contains(chosen.Models, running) { + evict = append(evict, running) + } + } + + return solveResult{ + Evict: evict, + TargetSet: chosen.Models, + SetName: chosen.SetName, + DSL: chosen.DSL, + TotalCost: bestCost, + } +} + +// findMatchingSet finds the expanded set that contains all running models. +// Returns the set name and DSL, or empty strings if no match. 
+func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) { + for _, idx := range s.modelToSets[requestedModel] { + set := s.expandedSets[idx] + allInSet := true + for _, m := range runningModels { + if !slices.Contains(set.Models, m) { + allInSet = false + break + } + } + if allInSet { + return set.SetName, set.DSL + } + } + return "", "" +} + +func (s *matrixSolver) evictCost(model string) int { + if cost, ok := s.evictCosts[model]; ok { + return cost + } + return 1 +} diff --git a/internal/router/matrix_test.go b/internal/router/matrix_test.go new file mode 100644 index 0000000..fb688d4 --- /dev/null +++ b/internal/router/matrix_test.go @@ -0,0 +1,244 @@ +package router + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// newTestMatrix builds a Matrix router from supplied processes, bypassing +// NewMatrix's call to process.New. +func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix { + t.Helper() + logger := logmon.NewWriter(io.Discard) + planner := &matrixPlanner{ + solver: newMatrixSolver(expanded, evictCosts), + processes: processes, + logger: logger, + } + base := newBaseRouter("matrix", conf, processes, planner, logger) + base.testProcessed = make(chan struct{}, 64) + r := &Matrix{baseRouter: base} + go base.run() + t.Cleanup(func() { + if !r.shuttingDown.Load() { + _ = r.Shutdown(time.Second) + } + }) + return r +} + +func baseMatrixConfig() config.Config { + return config.Config{ + HealthCheckTimeout: 5, + Matrix: &config.MatrixConfig{}, + } +} + +// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers +// eviction of running models that are not in any shared set with it. 
+func TestMatrix_SwapEvictsConflicting(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) // park a Run goroutine so Stop has something to release + + b := newFakeProcess("b") + b.autoReady = true + + // Two single-model sets: a and b never coexist, so loading b must evict a. + expanded := []config.ExpandedSet{ + {SetName: "s_a", DSL: "a", Models: []string{"a"}}, + {SetName: "s_b", DSL: "b", Models: []string{"b"}}, + } + r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b}) + + w := httptest.NewRecorder() + r.ServeHTTP(w, newRequest("b")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := a.stopCalls.Load(); got != 1 { + t.Errorf("a.stopCalls=%d want 1", got) + } + if got := b.runCalls.Load(); got != 1 { + t.Errorf("b.runCalls=%d want 1", got) + } +} + +// TestMatrix_CoexistInSet verifies that a model is not evicted when the target +// shares a set with it (the fast path applies if the target is already ready). +func TestMatrix_CoexistInSet(t *testing.T) { + a := newFakeProcess("a") + a.markReady() + go a.Run(0) + + b := newFakeProcess("b") + b.autoReady = true + + // Both fit in s_ab, so b's swap should not stop a. 
// TestMatrix_CoexistingSetParallel verifies that two models that share an
// expanded set load in parallel — the solver returns empty Evict for both,
// the collision predicate clears them, and both swaps run together.
//
// Neither fake process is marked ready up front, so both requests stay blocked
// until the test calls markReady on each of them below.
func TestMatrix_CoexistingSetParallel(t *testing.T) {
	a := newFakeProcess("a")
	pb := newFakeProcess("b")

	// A single set holds both models, so neither should evict the other.
	expanded := []config.ExpandedSet{
		{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
	}
	r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})

	// First request: served from its own goroutine; done1 closes when
	// ServeHTTP returns.
	w1 := httptest.NewRecorder()
	done1 := make(chan struct{})
	go func() {
		r.ServeHTTP(w1, newRequest("a"))
		close(done1)
	}()
	// Presumably waits for one event on testProcessed before continuing —
	// confirm against waitProcessed.
	waitProcessed(t, r.testProcessed, 1)

	// Second request, issued while the first is still in flight.
	w2 := httptest.NewRecorder()
	done2 := make(chan struct{})
	go func() {
		r.ServeHTTP(w2, newRequest("b"))
		close(done2)
	}()
	waitProcessed(t, r.testProcessed, 1)

	// Both Run calls must have started before either model is marked ready;
	// this is what demonstrates the loads overlap.
	// NOTE(review): these receives have no timeout, so a regression would hang
	// the test rather than fail it.
	<-a.runStarted
	<-pb.runStarted

	a.markReady()
	pb.markReady()

	// Each request gets up to one second to finish once its model is ready.
	for i, ch := range []chan struct{}{done1, done2} {
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatalf("request %d did not complete", i)
		}
	}
	if got := a.stopCalls.Load(); got != 0 {
		t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
	}
	if got := pb.stopCalls.Load(); got != 0 {
		t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
	}
}
// TestMatrix_IncompatibleQueues verifies that the second request for a model
// that cannot coexist with the in-flight first model queues until the first
// completes, and then evicts it. This exercises the alsoRunning hint via the
// matrix solver's union into runningSet.
func TestMatrix_IncompatibleQueues(t *testing.T) {
	a := newFakeProcess("a")
	pb := newFakeProcess("b")

	// Two single-model sets: a and b can never run together.
	expanded := []config.ExpandedSet{
		{SetName: "s_a", DSL: "a", Models: []string{"a"}},
		{SetName: "s_b", DSL: "b", Models: []string{"b"}},
	}
	r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})

	// First request: served from its own goroutine; done1 closes when
	// ServeHTTP returns.
	w1 := httptest.NewRecorder()
	done1 := make(chan struct{})
	go func() {
		r.ServeHTTP(w1, newRequest("a"))
		close(done1)
	}()
	// Presumably waits for one event on testProcessed before continuing —
	// confirm against waitProcessed.
	waitProcessed(t, r.testProcessed, 1)

	// B arrives before A transitions to StateStarting. The solver sees A via
	// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
	w2 := httptest.NewRecorder()
	done2 := make(chan struct{})
	go func() {
		r.ServeHTTP(w2, newRequest("b"))
		close(done2)
	}()
	waitProcessed(t, r.testProcessed, 1)

	// b must not have been started while a's swap is still in flight.
	if got := pb.runCalls.Load(); got != 0 {
		t.Errorf("b started in parallel: runCalls=%d want 0", got)
	}

	// Let a finish loading; only then is b expected to start.
	// NOTE(review): the channel receives below have no timeout, so a
	// regression would hang the test rather than fail it.
	<-a.runStarted
	a.markReady()
	waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
	<-pb.runStarted
	pb.markReady()

	// Each request gets up to one second to finish once its model is ready.
	for i, ch := range []chan struct{}{done1, done2} {
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatalf("request %d did not complete", i)
		}
	}
	if got := a.stopCalls.Load(); got != 1 {
		t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
	}
}
// peerMember is one configured remote peer as seen by the router.
type peerMember struct {
	// peerID is the peer's key in the peers configuration; used in log lines.
	peerID string
	// reverseProxy forwards requests to the peer's ProxyURL.
	reverseProxy *httputil.ReverseProxy
	// apiKey, when non-empty, is sent to the peer as both the
	// "Authorization: Bearer" and "x-api-key" request headers.
	apiKey string
}

// Peer is an http.Handler that routes a request to a remote peer based on the
// model ID extracted from the request.
type Peer struct {
	// cfg is the configuration handed to NewPeer; ServeHTTP passes it to
	// FetchContext.
	cfg    config.Config
	logger *logmon.Monitor
	// peers maps a model ID to the peer that serves it. It is filled once by
	// NewPeer; when several peers list the same model, the first peer in
	// sorted peer-ID order keeps it.
	peers map[string]*peerMember

	// shutdownCtx is cancelled through shutdownFn to abort proxied requests
	// that are still in flight.
	shutdownCtx context.Context
	shutdownFn  context.CancelFunc
	// shuttingDown is set by Shutdown; ServeHTTP rejects requests once it is true.
	shuttingDown atomic.Bool
	// inflight counts ServeHTTP calls that are currently running.
	inflight sync.WaitGroup
}
Local Network permissions)" + } + http.Error(w, errMsg, http.StatusBadGateway) + } + + pp := &peerMember{ + peerID: peerID, + reverseProxy: reverseProxy, + apiKey: peer.ApiKey, + } + + for _, modelID := range peer.Models { + if _, found := modelMap[modelID]; found { + logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID) + continue + } + modelMap[modelID] = pp + } + } + + shutdownCtx, shutdownFn := context.WithCancel(context.Background()) + + return &Peer{ + cfg: cfg, + logger: logger, + peers: modelMap, + shutdownCtx: shutdownCtx, + shutdownFn: shutdownFn, + }, nil +} + +func (r *Peer) Handles(model string) bool { + _, ok := r.peers[model] + return ok +} + +func (r *Peer) Shutdown(timeout time.Duration) error { + if !r.shuttingDown.CompareAndSwap(false, true) { + return fmt.Errorf("shutdown already in progress") + } + + if timeout == 0 { + r.shutdownFn() + r.inflight.Wait() + return nil + } + + done := make(chan struct{}) + go func() { + r.inflight.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-time.After(timeout): + r.shutdownFn() + r.inflight.Wait() + return fmt.Errorf("peer shutdown timed out after %v", timeout) + } +} + +func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.shuttingDown.Load() { + SendError(w, req, fmt.Errorf("peer proxy is shutting down")) + return + } + r.inflight.Add(1) + defer r.inflight.Done() + + data, err := FetchContext(req, r.cfg) + if err != nil { + SendError(w, req, err) + return + } + + pp, found := r.peers[data.ModelID] + if !found { + r.logger.Warnf("peer model not found: %s", data.ModelID) + SendError(w, req, ErrNoPeerModelFound) + return + } + + r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID) + + if pp.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+pp.apiKey) + req.Header.Set("x-api-key", pp.apiKey) + } + + // Cancel the proxy request when the client disconnects or shutdown times out. 
+ // AfterFunc links both parent contexts to our child without a goroutine leak. + ctx, cancel := context.WithCancel(context.Background()) + stopReq := context.AfterFunc(req.Context(), cancel) + stopShutdown := context.AfterFunc(r.shutdownCtx, cancel) + req = req.WithContext(ctx) + + pp.reverseProxy.ServeHTTP(w, req) + + stopShutdown() + stopReq() + cancel() +} diff --git a/internal/router/peer_test.go b/internal/router/peer_test.go new file mode 100644 index 0000000..74527bc --- /dev/null +++ b/internal/router/peer_test.go @@ -0,0 +1,611 @@ +package router + +import ( + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +var testLogger = logmon.NewWriter(os.Stdout) + +func init() { + testLogger.SetLogLevel(logmon.LevelWarn) +} + +func TestNewPeer_EmptyPeers(t *testing.T) { + pr, err := NewPeer(config.Config{}, testLogger) + if err != nil { + t.Fatal(err) + } + if pr == nil { + t.Fatal("expected non-nil Peer") + } + if len(pr.peers) != 0 { + t.Fatalf("expected empty peers map, got %d entries", len(pr.peers)) + } +} + +func TestNewPeer_SinglePeer(t *testing.T) { + proxyURL, _ := url.Parse("http://peer1.example.com:8080") + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL, + ApiKey: "test-key", + Models: []string{"model-a", "model-b"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + if len(pr.peers) != 2 { + t.Fatalf("expected 2 entries, got %d", len(pr.peers)) + } + if _, ok := pr.peers["model-a"]; !ok { + t.Error("expected model-a to be mapped") + } + if _, ok := pr.peers["model-b"]; !ok { + t.Error("expected model-b to be mapped") + } + if _, ok := pr.peers["model-c"]; ok { + t.Error("expected model-c to not be mapped") + } +} + +func TestNewPeer_MultiplePeers(t 
*testing.T) { + proxyURL1, _ := url.Parse("http://peer1.example.com:8080") + proxyURL2, _ := url.Parse("http://peer2.example.com:8080") + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL1, + Models: []string{"model-a", "model-b"}, + }, + "peer2": config.PeerConfig{ + Proxy: "http://peer2.example.com:8080", + ProxyURL: proxyURL2, + Models: []string{"model-c", "model-d"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + if len(pr.peers) != 4 { + t.Fatalf("expected 4 entries, got %d", len(pr.peers)) + } + for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} { + if _, ok := pr.peers[m]; !ok { + t.Errorf("expected %s to be mapped", m) + } + } +} + +func TestNewPeer_DuplicateModel(t *testing.T) { + proxyURL1, _ := url.Parse("http://peer1.example.com:8080") + proxyURL2, _ := url.Parse("http://peer2.example.com:8080") + peers := config.PeerDictionaryConfig{ + "alpha-peer": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL1, + Models: []string{"duplicate-model"}, + }, + "beta-peer": config.PeerConfig{ + Proxy: "http://peer2.example.com:8080", + ProxyURL: proxyURL2, + Models: []string{"duplicate-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + if len(pr.peers) != 1 { + t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers)) + } + if _, ok := pr.peers["duplicate-model"]; !ok { + t.Error("expected duplicate-model to be mapped") + } +} + +func TestPeer_ServeHTTP_Success(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("response from peer")) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + 
Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + if w.Body.String() != "response from peer" { + t.Errorf("expected 'response from peer', got %q", w.Body.String()) + } +} + +func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) { + pr, err := NewPeer(config.Config{}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) { + pr, err := NewPeer(config.Config{}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) { + var receivedAuthHeader string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + 
ApiKey: "secret-api-key", + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if receivedAuthHeader != "Bearer secret-api-key" { + t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader) + } +} + +func TestPeer_ServeHTTP_NoApiKey(t *testing.T) { + var receivedAuthHeader string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + ApiKey: "", + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if receivedAuthHeader != "" { + t.Errorf("expected no auth header, got %q", receivedAuthHeader) + } +} + +func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) { + var receivedHost string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHost = r.Host + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := 
NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if !strings.HasPrefix(receivedHost, "127.0.0.1:") { + t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost) + } +} + +func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Header().Get("X-Accel-Buffering") != "no" { + t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering")) + } +} + +func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + err = pr.Shutdown(0) 
+ if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "shutting down") { + t.Errorf("expected 'shutting down' in body, got %q", w.Body.String()) + } +} + +func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) { + started := make(chan struct{}) + released := make(chan struct{}) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(started) + <-released + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + w := httptest.NewRecorder() + pr.ServeHTTP(w, req) + }() + + <-started + + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- pr.Shutdown(500 * time.Millisecond) + }() + + // Shutdown should be waiting on inflight. If it finished already something is wrong. 
+ time.Sleep(100 * time.Millisecond) + select { + case err := <-shutdownDone: + t.Errorf("shutdown completed before inflight finished: %v", err) + default: + } + + close(released) + wg.Wait() + + select { + case err := <-shutdownDone: + if err != nil { + t.Errorf("shutdown errored after inflight completed: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("shutdown did not complete after inflight finished") + } +} + +func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) { + started := make(chan struct{}) + released := make(chan struct{}) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(started) + <-released + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + w := httptest.NewRecorder() + pr.ServeHTTP(w, req) + }() + + <-started + + err = pr.Shutdown(100 * time.Millisecond) + if err == nil { + t.Error("expected timeout error from shutdown") + } + + close(released) + wg.Wait() +} + +func TestPeer_ShutdownMultiple(t *testing.T) { + pr, err := NewPeer(config.Config{}, testLogger) + if err != nil { + t.Fatal(err) + } + + err = pr.Shutdown(0) + if err != nil { + t.Fatal(err) + } + + err = pr.Shutdown(0) + if err == nil { + t.Error("expected error on second shutdown") + } + if !strings.Contains(err.Error(), "already in progress") { + t.Errorf("expected 'already in progress', got %q", err.Error()) + } +} + +func 
TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"extracted-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`) + req := httptest.NewRequest("POST", "/v1/chat/completions", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"context-model"}, + }, + "peer2": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"body-model"}, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`) + req := httptest.NewRequest("POST", "/v1/chat/completions", body) + req.Header.Set("Content-Type", "application/json") + *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"})) + w := httptest.NewRecorder() + + pr.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + 
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestNewPeer_CustomTimeouts(t *testing.T) { + proxyURL, _ := url.Parse("http://localhost:8080") + peers := config.PeerDictionaryConfig{ + "test-peer": config.PeerConfig{ + Proxy: "http://localhost:8080", + ProxyURL: proxyURL, + Models: []string{"model1"}, + Timeouts: config.TimeoutsConfig{ + Connect: 45, + ResponseHeader: 300, + TLSHandshake: 15, + ExpectContinue: 2, + IdleConn: 120, + }, + }, + } + + pr, err := NewPeer(config.Config{Peers: peers}, testLogger) + if err != nil { + t.Fatal(err) + } + + member, ok := pr.peers["model1"] + if !ok { + t.Fatal("expected model1 to be mapped") + } + + transport, ok := member.reverseProxy.Transport.(*http.Transport) + if !ok { + t.Fatal("expected Transport to be *http.Transport") + } + + if transport.ResponseHeaderTimeout != 300*time.Second { + t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout) + } + if transport.TLSHandshakeTimeout != 15*time.Second { + t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout) + } + if transport.ExpectContinueTimeout != 2*time.Second { + t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout) + } + if transport.IdleConnTimeout != 120*time.Second { + t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout) + } + if !transport.ForceAttemptHTTP2 { + t.Error("expected ForceAttemptHTTP2 to be true") + } +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..4bc8d8c --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,199 @@ +package router + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" + 
"github.com/tidwall/gjson" +) + +type contextkey struct { + name string +} + +type ReqContextData struct { + Model string + ModelID string + Streaming bool + SendLoadingState bool +} + +var ( + ErrNoModelInContext = fmt.Errorf("no model in request context") + ErrNoRouterFound = fmt.Errorf("no router found for model") + ErrNoPeerModelFound = fmt.Errorf("peer model not found") + ErrNoLocalModelFound = fmt.Errorf("local model not found") + + ContextKey = &contextkey{"context"} +) + +type Router interface { + // Shutdown blocks until the router has shutdown returning nil + // when the router has shutdown successfully. + // + // timeout controls how long to wait for inflight requests to finish. After + // the timeout all inflight requests will be cancelled. + Shutdown(timeout time.Duration) error + + // ServeHTTP implements the http.Handler and requests coming in will + // trigger any model swapping and routing logic. + ServeHTTP(http.ResponseWriter, *http.Request) + + // Handles reports whether this router can serve requests for the given model. + Handles(model string) bool +} + +// LocalRouter is a Router backed by local processes whose state can be +// inspected and which can be individually stopped. Peer routers, which only +// forward to remote hosts, do not implement it. +type LocalRouter interface { + Router + + // RunningModels returns the current state of every process that is not + // stopped or shut down, keyed by model ID. + RunningModels() map[string]process.ProcessState + + // Unload stops the named models, or every running model when none are + // named. It blocks until each targeted process has stopped. + Unload(timeout time.Duration, models ...string) + + // ProcessLogger returns the log monitor for the named model's process. + // modelID must be a real (non-alias) config key. Returns false when the + // model is not known to this router. 
+ ProcessLogger(modelID string) (*logmon.Monitor, bool) +} + +// FetchContext will attempt to get the model id from the context then +// from the model body. If it extracts the model from the body it will +// store the model in the context for downstream handlers. An error +// will be returned when model can not be fetch from either location. +func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { + data, ok := ReadContext(r.Context()) + if ok { + return data, nil + } + + if data, err := ExtractContext(r); err == nil { + realName, _ := cfg.RealModelName(data.Model) + if realName == "" { + realName = data.Model + } + data.ModelID = realName + if mc, ok := cfg.Models[realName]; ok { + data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState + } + *r = *r.WithContext(SetContext(r.Context(), data)) + return data, nil + } + + return ReqContextData{}, ErrNoModelInContext +} + +func SetContext(ctx context.Context, data ReqContextData) context.Context { + return context.WithValue(ctx, ContextKey, data) +} + +func ReadContext(ctx context.Context) (ReqContextData, bool) { + data, ok := ctx.Value(ContextKey).(ReqContextData) + return data, ok +} + +// ExtractContext pulls the model name from an HTTP request without consuming the +// body. For GET requests it reads the "model" query parameter. For POST +// requests it inspects Content-Type and parses JSON, multipart/form-data, or +// application/x-www-form-urlencoded bodies. The request body is always restored +// before returning so downstream handlers — including reverse proxies that +// forward raw bytes upstream — can still read it. 
+func ExtractContext(r *http.Request) (ReqContextData, error) { + if r.Method == http.MethodGet { + if model := r.URL.Query().Get("model"); model != "" { + return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil + } + return ReqContextData{}, fmt.Errorf("missing 'model' query parameter") + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return ReqContextData{}, fmt.Errorf("error reading request body: %w", err) + } + defer func() { + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + }() + + contentType := r.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/json") { + model := gjson.GetBytes(bodyBytes, "model").String() + if model == "" { + return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body") + } + return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil + } + + // Form parsers read from r.Body, so feed them a fresh reader over the + // buffered bytes. The deferred restore above will reset r.Body again + // after parsing. 
+ r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + if strings.Contains(contentType, "multipart/form-data") { + if err := r.ParseMultipartForm(32 << 20); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err) + } + } else { + if err := r.ParseForm(); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing form: %w", err) + } + } + + if model := r.FormValue("model"); model != "" { + return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil + } + + return ReqContextData{}, fmt.Errorf("missing 'model' parameter") +} + +func SendError(w http.ResponseWriter, r *http.Request, err error) { + switch { + case errors.Is(err, ErrNoModelInContext): + SendResponse(w, r, http.StatusNotFound, "no model id could be identified") + case errors.Is(err, ErrNoPeerModelFound): + SendResponse(w, r, http.StatusNotFound, "no peer found for requested model") + case errors.Is(err, ErrNoLocalModelFound): + SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") + case errors.Is(err, ErrNoRouterFound): + SendResponse(w, r, http.StatusNotFound, "no router for requested model") + default: + SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err)) + } +} + +// SendResponse detects what content type the client prefers and returns an error response in that format. +func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) { + // Check Accept header for preferred response format + acceptHeader := r.Header.Get("Accept") + if strings.Contains(acceptHeader, "text/plain") { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf("llama-swap: %s", message))) + return + } + + if strings.Contains(acceptHeader, "text/html") { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf(`

llama-swap

%s

`, message))) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message))) +} diff --git a/internal/router/router_test.go b/internal/router/router_test.go new file mode 100644 index 0000000..fa88364 --- /dev/null +++ b/internal/router/router_test.go @@ -0,0 +1,275 @@ +package router + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "net/url" + "strings" + "testing" +) + +func TestExtractContext_GET(t *testing.T) { + tests := []struct { + name string + query string + wantModel string + wantErr bool + }{ + {"model present", "model=llama3", "llama3", false}, + {"model with slashes", "model=author/model-7b", "author/model-7b", false}, + {"model missing", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) + got, err := ExtractContext(r) + if (err != nil) != tt.wantErr { + t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) + } + if got.Model != tt.wantModel { + t.Errorf("want %q got %q", tt.wantModel, got.Model) + } + }) + } +} + +func TestExtractContext_JSON(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantErr bool + }{ + {"model present", `{"model":"llama3","stream":true}`, "llama3", false}, + {"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false}, + {"model empty string", `{"model":""}`, "", true}, + {"model key missing", `{"stream":true}`, "", true}, + {"invalid json", `not-json`, "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") + got, err := ExtractContext(r) + if (err != nil) != tt.wantErr { + t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) + } + if got.Model != tt.wantModel { + 
t.Errorf("want %q got %q", tt.wantModel, got.Model) + } + }) + } +} + +func TestExtractContext_URLEncodedForm(t *testing.T) { + tests := []struct { + name string + formModel string + wantModel string + wantErr bool + }{ + {"model present", "whisper-1", "whisper-1", false}, + {"model missing", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + form := url.Values{} + if tt.formModel != "" { + form.Set("model", tt.formModel) + } + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := ExtractContext(r) + if (err != nil) != tt.wantErr { + t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) + } + if got.Model != tt.wantModel { + t.Errorf("want %q got %q", tt.wantModel, got.Model) + } + }) + } +} + +func TestExtractContext_MultipartForm(t *testing.T) { + tests := []struct { + name string + formModel string + wantModel string + wantErr bool + }{ + {"model present", "whisper-1", "whisper-1", false}, + {"model missing", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + if tt.formModel != "" { + fw, _ := mw.CreateFormField("model") + fw.Write([]byte(tt.formModel)) + } + mw.Close() + + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf) + r.Header.Set("Content-Type", mw.FormDataContentType()) + got, err := ExtractContext(r) + if (err != nil) != tt.wantErr { + t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) + } + if got.Model != tt.wantModel { + t.Errorf("want %q got %q", tt.wantModel, got.Model) + } + }) + } +} + +func TestExtractContext_JSONBodyRestored(t *testing.T) { + body := `{"model":"llama3","stream":true}` + r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/json") + + if _, err := ExtractContext(r); 
err != nil { + t.Fatalf("ExtractContext: %v", err) + } + + remaining, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("reading body after ExtractContext: %v", err) + } + if string(remaining) != body { + t.Errorf("body not restored: want %q got %q", body, string(remaining)) + } +} + +func TestExtractContext_MultipartBodyRestored(t *testing.T) { + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + fw, _ := mw.CreateFormField("model") + fw.Write([]byte("whisper-1")) + ff, _ := mw.CreateFormFile("file", "audio.wav") + ff.Write([]byte("fake-audio-bytes")) + mw.Close() + + original := buf.Bytes() + + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original)) + r.Header.Set("Content-Type", mw.FormDataContentType()) + + if _, err := ExtractContext(r); err != nil { + t.Fatalf("ExtractContext: %v", err) + } + + remaining, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("reading body after ExtractContext: %v", err) + } + if !bytes.Equal(remaining, original) { + t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining)) + } +} + +func TestExtractContext_URLEncodedBodyRestored(t *testing.T) { + body := "model=whisper-1&extra=value" + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + if _, err := ExtractContext(r); err != nil { + t.Fatalf("ExtractContext: %v", err) + } + + remaining, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("reading body after ExtractContext: %v", err) + } + if string(remaining) != body { + t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining)) + } +} + +func TestSetContext(t *testing.T) { + ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}) + data, ok := ctx.Value(ContextKey).(ReqContextData) + if !ok { + t.Fatalf("ContextKey not set or wrong type") + } + if data.Model 
!= "llama3" { + t.Errorf("want %q got %q", "llama3", data.Model) + } + if data.ModelID != "llama3" { + t.Errorf("want %q got %q", "llama3", data.ModelID) + } +} + +func TestSetContext_WithAlias(t *testing.T) { + ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}) + data, _ := ctx.Value(ContextKey).(ReqContextData) + if data.Model != "llama" { + t.Errorf("want requested %q got %q", "llama", data.Model) + } + if data.ModelID != "llama3" { + t.Errorf("want real %q got %q", "llama3", data.ModelID) + } +} + +func TestSetContext_DoesNotMutateParent(t *testing.T) { + parent := context.Background() + _ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"}) + if v := parent.Value(ContextKey); v != nil { + t.Errorf("parent context was mutated: %v", v) + } +} + +func TestReadContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + wantReq string + wantReal string + wantBool bool + }{ + { + name: "model present, same name", + ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}), + wantReq: "llama3", + wantReal: "llama3", + wantBool: true, + }, + { + name: "model present, aliased", + ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}), + wantReq: "llama", + wantReal: "llama3", + wantBool: true, + }, + { + name: "model absent", + ctx: context.Background(), + wantReq: "", + wantReal: "", + wantBool: false, + }, + { + name: "model is empty string", + ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}), + wantReq: "", + wantReal: "", + wantBool: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotData, ok := ReadContext(tt.ctx) + if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool { + t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok) + } + }) + } +} diff --git 
a/internal/server/api.go b/internal/server/api.go new file mode 100644 index 0000000..b1df34f --- /dev/null +++ b/internal/server/api.go @@ -0,0 +1,266 @@ +package server + +import ( + "encoding/json" + "net/http" + "sort" + "strings" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +// modelRecord is one entry in the OpenAI-compatible /v1/models listing. +type modelRecord struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Meta map[string]any `json:"meta,omitempty"` +} + +// handleListModels serves the OpenAI-compatible model listing: local models +// (with optional aliases) plus peer models. +func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { + created := time.Now().Unix() + data := make([]modelRecord, 0, len(s.cfg.Models)) + + newRecord := func(id, name, description string, metadata map[string]any) modelRecord { + rec := modelRecord{ + ID: id, + Object: "model", + Created: created, + OwnedBy: "llama-swap", + Name: strings.TrimSpace(name), + Description: strings.TrimSpace(description), + } + if len(metadata) > 0 { + rec.Meta = map[string]any{"llamaswap": metadata} + } + return rec + } + + for id, mc := range s.cfg.Models { + if mc.Unlisted { + continue + } + data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata)) + + if s.cfg.IncludeAliasesInList { + for _, alias := range mc.Aliases { + if alias := strings.TrimSpace(alias); alias != "" { + data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata)) + } + } + } + } + + for peerID, peer := range s.cfg.Peers { + for _, modelID := range peer.Models { + data = append(data, newRecord(modelID, peerID+": "+modelID, "", 
map[string]any{"peerID": peerID})) + } + } + + sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID }) + + // Echo the Origin so browser clients can read the listing. + if origin := r.Header.Get("Origin"); origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "object": "list", + "data": data, + }) +} + +// runningModel is one entry in the /running listing. +type runningModel struct { + Model string `json:"model"` + State string `json:"state"` + Cmd string `json:"cmd"` + Proxy string `json:"proxy"` + TTL int `json:"ttl"` + Name string `json:"name"` + Description string `json:"description"` +} + +// handleUnload stops every running local process. Peer models are remote and +// unaffected. +func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) { + s.local.Unload(0) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +// handleRunning lists local processes that are not stopped, joining each model +// ID against its config for the cmd/proxy/ttl/name/description metadata. +func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) { + states := s.local.RunningModels() + list := make([]runningModel, 0, len(states)) + for id, state := range states { + mc := s.cfg.Models[id] + list = append(list, runningModel{ + Model: id, + State: string(state), + Cmd: mc.Cmd, + Proxy: mc.Proxy, + TTL: mc.UnloadAfter, + Name: mc.Name, + Description: mc.Description, + }) + } + sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model }) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"running": list}) +} + +// discardResponseWriter satisfies http.ResponseWriter for preload requests, +// dropping the body while capturing the status code. 
+type discardResponseWriter struct { + header http.Header + status int +} + +func (d *discardResponseWriter) Header() http.Header { + if d.header == nil { + d.header = make(http.Header) + } + return d.header +} + +func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil } + +func (d *discardResponseWriter) WriteHeader(status int) { d.status = status } + +// startPreload fires a background GET / at every model named in +// Hooks.OnStartup.Preload so they are warm before the first real request. +// Preload names are already resolved to real model IDs by config loading. +func (s *Server) startPreload() { + models := s.cfg.Hooks.OnStartup.Preload + if len(models) == 0 { + return + } + go func() { + for _, modelID := range models { + if !s.local.Handles(modelID) { + s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID) + continue + } + s.proxylog.Infof("preloading model: %s", modelID) + + req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil) + if err != nil { + continue + } + req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID})) + + dw := &discardResponseWriter{status: http.StatusOK} + s.local.ServeHTTP(dw, req) + + success := dw.status < http.StatusBadRequest + if !success { + s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status) + } + event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success}) + } + }() +} + +// handleMetrics serves Prometheus-format performance metrics. Returns 503 when +// performance monitoring is disabled. 
+func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) { + if s.perf == nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("# performance monitor not available\n")) + return + } + s.perf.MetricsHandler().ServeHTTP(w, r) +} + +func handleHealth(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +func handleRootRedirect(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/ui", http.StatusFound) +} + +func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/ui/models", http.StatusFound) +} + +// handleUpstream proxies ANY request under /upstream// directly to +// the model's process, bypassing model dispatch by body/query inspection. +func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { + upstreamPath := r.PathValue("upstreamPath") + + searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) + if !found { + router.SendResponse(w, r, http.StatusNotFound, "model not found") + return + } + + // Redirect /upstream/model to /upstream/model/ so relative URLs in upstream + // responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method. + if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") { + newPath := "/upstream/" + searchName + "/" + if r.URL.RawQuery != "" { + newPath += "?" + r.URL.RawQuery + } + if r.Method == http.MethodGet || r.Method == http.MethodHead { + http.Redirect(w, r, newPath, http.StatusMovedPermanently) + } else { + http.Redirect(w, r, newPath, http.StatusPermanentRedirect) + } + return + } + + // Strip the /upstream/ prefix before forwarding. + r.URL.Path = remainingPath + // Pin the resolved model so the router skips body/query extraction. 
+ *r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID})) + + switch { + case s.local.Handles(modelID): + s.local.ServeHTTP(w, r) + case s.peer.Handles(modelID): + s.peer.ServeHTTP(w, r) + default: + router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) + } +} + +// findModelInPath walks a slash-separated path, building up segments until one +// matches a configured model. This resolves model names that contain slashes +// (e.g. "author/model"). Returns the matched name, its real model ID, the +// remaining path, and whether a match was found. +func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) { + parts := strings.Split(strings.TrimSpace(path), "/") + name := "" + + for i, part := range parts { + if part == "" { + continue + } + if name == "" { + name = part + } else { + name = name + "/" + part + } + + if modelID, ok := cfg.RealModelName(name); ok { + return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true + } + } + + return "", "", "", false +} diff --git a/internal/server/api_test.go b/internal/server/api_test.go new file mode 100644 index 0000000..fe1967f --- /dev/null +++ b/internal/server/api_test.go @@ -0,0 +1,159 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" +) + +func TestServer_HandleListModels(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.cfg = config.Config{ + Models: map[string]config.ModelConfig{ + "visible": {Name: "Visible", Description: "a model"}, + "hidden": {Unlisted: true}, + }, + Peers: config.PeerDictionaryConfig{ + "peer1": {Models: []string{"remote-model"}}, + }, + } + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Origin", "http://example.com") + s.ServeHTTP(w, req) + + if w.Code != 
http.StatusOK { + t.Fatalf("status = %d", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" { + t.Errorf("Access-Control-Allow-Origin = %q", got) + } + + var resp struct { + Data []modelRecord `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode: %v", err) + } + ids := map[string]bool{} + for _, m := range resp.Data { + ids[m.ID] = true + } + if !ids["visible"] || !ids["remote-model"] { + t.Errorf("missing expected models: %v", ids) + } + if ids["hidden"] { + t.Error("unlisted model should not appear") + } +} + +func TestServer_HandleListModels_Aliases(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.cfg = config.Config{ + IncludeAliasesInList: true, + Models: map[string]config.ModelConfig{ + "real": {Aliases: []string{"nick"}}, + }, + } + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil)) + + var resp struct { + Data []modelRecord `json:"data"` + } + json.Unmarshal(w.Body.Bytes(), &resp) + ids := map[string]bool{} + for _, m := range resp.Data { + ids[m.ID] = true + } + if !ids["real"] || !ids["nick"] { + t.Errorf("expected alias entry; got %v", ids) + } +} + +func TestServer_FindModelInPath(t *testing.T) { + cfg := config.Config{Models: map[string]config.ModelConfig{ + "author/model": {}, + "simple": {}, + }} + + cases := []struct { + path string + wantName string + wantRem string + wantFound bool + }{ + {"/simple/v1/chat", "simple", "/v1/chat", true}, + {"/author/model/v1/chat", "author/model", "/v1/chat", true}, + {"/author/model", "author/model", "/", true}, + {"/missing/v1", "", "", false}, + {"/", "", "", false}, + } + for _, c := range cases { + name, _, rem, found := findModelInPath(cfg, c.path) + if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) { + t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)", + c.path, name, rem, found, 
c.wantName, c.wantRem, c.wantFound) + } + } +} + +func TestServer_HandleUpstream(t *testing.T) { + local := newStubRouter([]string{"m1"}, "upstream-body") + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}} + + t.Run("proxies to local", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil)) + if w.Code != http.StatusOK || w.Body.String() != "upstream-body" { + t.Errorf("status=%d body=%q", w.Code, w.Body.String()) + } + }) + + t.Run("redirects bare model path", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil)) + if w.Code != http.StatusMovedPermanently { + t.Errorf("status = %d, want 301", w.Code) + } + }) + + t.Run("unknown model 404", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil)) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } + }) +} + +func TestServer_HandleMetrics_Unavailable(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil)) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want 503", w.Code) + } +} + +func TestServer_Redirects(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil)) + if w.Code != http.StatusFound { + t.Errorf("%s: status = %d, want 302", path, w.Code) + } + if got := w.Header().Get("Location"); got != want { + t.Errorf("%s: Location = %q, want %q", path, got, want) + } + } +} diff --git a/internal/server/apigroup.go 
b/internal/server/apigroup.go new file mode 100644 index 0000000..e71b29c --- /dev/null +++ b/internal/server/apigroup.go @@ -0,0 +1,270 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/perf" + "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +// apiModel is one entry in the /api/events modelStatus payload. +type apiModel struct { + Id string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + State string `json:"state"` + Unlisted bool `json:"unlisted"` + PeerID string `json:"peerID"` + Aliases []string `json:"aliases,omitempty"` +} + +// modelStatus returns every configured model joined with its current process +// state (defaulting to "stopped"), followed by peer models. +func (s *Server) modelStatus() []apiModel { + running := s.local.RunningModels() + + ids := make([]string, 0, len(s.cfg.Models)) + for id := range s.cfg.Models { + ids = append(ids, id) + } + sort.Strings(ids) + + models := make([]apiModel, 0, len(ids)) + for _, id := range ids { + mc := s.cfg.Models[id] + state := "stopped" + if st, ok := running[id]; ok { + state = string(st) + } + models = append(models, apiModel{ + Id: id, + Name: mc.Name, + Description: mc.Description, + State: state, + Unlisted: mc.Unlisted, + Aliases: mc.Aliases, + }) + } + + for peerID, peer := range s.cfg.Peers { + for _, modelID := range peer.Models { + models = append(models, apiModel{Id: modelID, PeerID: peerID}) + } + } + + return models +} + +// handleAPIUnloadAll stops every running local process. 
+func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) { + s.local.Unload(0) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"msg": "ok"}) +} + +// handleAPIUnloadModel stops a single named local process. +func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) { + requested := strings.TrimPrefix(r.PathValue("model"), "/") + realName, found := s.cfg.RealModelName(requested) + if !found { + router.SendResponse(w, r, http.StatusNotFound, "model not found") + return + } + if !s.local.Handles(realName) { + router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") + return + } + s.local.Unload(0, realName) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +// handleAPIMetrics serves the activity log as a JSON array. +func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) { + data, err := s.metrics.getMetricsJSON() + if err != nil { + router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics") + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(data) +} + +// handleAPIPerformance serves the buffered system/GPU stats, optionally +// filtered to samples after the ?after= timestamp. 
+func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) { + if s.perf == nil { + router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available") + return + } + + sysStats, gpuStats := s.perf.Current() + + if afterStr := r.URL.Query().Get("after"); afterStr != "" { + after, err := time.Parse(time.RFC3339, afterStr) + if err != nil { + router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format") + return + } + filteredSys := make([]perf.SysStat, 0, len(sysStats)) + for _, st := range sysStats { + if st.Timestamp.After(after) { + filteredSys = append(filteredSys, st) + } + } + sysStats = filteredSys + + filteredGpu := make([]perf.GpuStat, 0, len(gpuStats)) + for _, g := range gpuStats { + if g.Timestamp.After(after) { + filteredGpu = append(filteredGpu, g) + } + } + gpuStats = filteredGpu + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "sys_stats": sysStats, + "gpu_stats": gpuStats, + }) +} + +// handleAPIVersion serves the build metadata. +func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "version": s.build.Version, + "commit": s.build.Commit, + "build_date": s.build.Date, + }) +} + +// handleAPICapture returns the stored request/response capture for a metric ID. 
+func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) { + id, err := strconv.Atoi(r.PathValue("id")) + if err != nil { + router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID") + return + } + + capture := s.metrics.getCaptureByID(id) + if capture == nil { + router.SendResponse(w, r, http.StatusNotFound, "capture not found") + return + } + + jsonBytes, err := json.Marshal(capture) + if err != nil { + router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture") + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(jsonBytes) +} + +type messageType string + +const ( + msgTypeModelStatus messageType = "modelStatus" + msgTypeLogData messageType = "logData" + msgTypeMetrics messageType = "metrics" + msgTypeInFlight messageType = "inflight" +) + +type messageEnvelope struct { + Type messageType `json:"type"` + Data string `json:"data"` +} + +// handleAPIEvents streams server events (model status, log data, metrics, +// in-flight counts) to the client as Server-Sent Events. 
+func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Content-Type-Options", "nosniff") + // prevent nginx from buffering SSE + w.Header().Set("X-Accel-Buffering", "no") + + flusher, ok := w.(http.Flusher) + if !ok { + router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") + return + } + + // internal/event already has a 50K event buffer + // a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full + sendBuffer := make(chan messageEnvelope, 1024) + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + send := func(msg messageEnvelope) { + select { + case sendBuffer <- msg: + case <-ctx.Done(): + s.proxylog.Warn("handleAPIEvents send suppressed due to context done") + default: + s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message") + } + } + sendModels := func() { + if data, err := json.Marshal(s.modelStatus()); err == nil { + send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}) + } + } + sendLogData := func(source string, data []byte) { + if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil { + send(messageEnvelope{Type: msgTypeLogData, Data: string(j)}) + } + } + sendMetrics := func(metrics []ActivityLogEntry) { + if j, err := json.Marshal(metrics); err == nil { + send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)}) + } + } + sendInFlight := func(total int) { + if j, err := json.Marshal(map[string]int{"total": total}); err == nil { + send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)}) + } + } + + defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })() + defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })() + defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", 
data) })() + defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })() + defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })() + defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })() + + // initial payload + sendLogData("proxy", s.proxylog.GetHistory()) + sendLogData("upstream", s.upstreamlog.GetHistory()) + sendModels() + sendMetrics(s.metrics.getMetrics()) + sendInFlight(int(s.inflight.Current())) + + for { + select { + case <-r.Context().Done(): + return + case <-s.shutdownCtx.Done(): + return + case msg := <-sendBuffer: + data, err := json.Marshal(msg) + if err != nil { + continue + } + fmt.Fprintf(w, "event:message\ndata:%s\n\n", data) + flusher.Flush() + } + } +} diff --git a/internal/server/apigroup_test.go b/internal/server/apigroup_test.go new file mode 100644 index 0000000..a8fa664 --- /dev/null +++ b/internal/server/apigroup_test.go @@ -0,0 +1,103 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestServer_InflightMiddleware(t *testing.T) { + c := &inflightCounter{} + mw := CreateInflightMiddleware(c) + + var duringRequest int64 + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + duringRequest = c.Current() + })) + + handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)) + + if duringRequest != 1 { + t.Errorf("counter during request = %d, want 1", duringRequest) + } + if got := c.Current(); got != 0 { + t.Errorf("counter after request = %d, want 0", got) + } +} + +func TestServer_APIVersion(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"} + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil)) + + if w.Code != http.StatusOK { + 
t.Fatalf("status = %d", w.Code) + } + var got map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" { + t.Errorf("body = %v", got) + } +} + +func TestServer_APIMetrics_Empty(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d", w.Code) + } + if body := strings.TrimSpace(w.Body.String()); body != "[]" { + t.Errorf("body = %q, want []", body) + } +} + +func TestServer_APIPerformance_Unavailable(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil)) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want 503", w.Code) + } +} + +func TestServer_APIEvents_InitialPayload(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx) + w := httptest.NewRecorder() + + done := make(chan struct{}) + go func() { + s.ServeHTTP(w, req) + close(done) + }() + + time.Sleep(100 * time.Millisecond) + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handler did not return after context cancel") + } + + body := w.Body.String() + for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} { + if !strings.Contains(body, want) { + t.Errorf("initial SSE payload missing %s; body=%q", want, body) + } + } +} diff --git a/internal/server/auth.go b/internal/server/auth.go new file mode 100644 index 0000000..e385b73 --- /dev/null +++ 
b/internal/server/auth.go @@ -0,0 +1,135 @@ +package server + +import ( + "encoding/base64" + "net/http" + "strings" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/router" +) + +// CreateAuthMiddleware returns middleware that validates API keys when the +// config declares any. It accepts the key via Authorization: Bearer, +// Authorization: Basic (password field), or x-api-key. On success the auth +// headers are stripped so they never leak to upstream. When no keys are +// configured the middleware is a pass-through. +func CreateAuthMiddleware(cfg config.Config) chain.Middleware { + keys := cfg.RequiredAPIKeys + return func(next http.Handler) http.Handler { + if len(keys) == 0 { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + provided := extractAPIKey(r) + + valid := false + for _, key := range keys { + if provided == key { + valid = true + break + } + } + if !valid { + w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`) + router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key") + return + } + + r.Header.Del("Authorization") + r.Header.Del("x-api-key") + next.ServeHTTP(w, r) + }) + } +} + +// extractAPIKey pulls a candidate API key from the request, preferring Basic, +// then Bearer, then x-api-key. 
+func extractAPIKey(r *http.Request) string { + var bearerKey, basicKey string + if auth := r.Header.Get("Authorization"); auth != "" { + if strings.HasPrefix(auth, "Bearer ") { + bearerKey = strings.TrimPrefix(auth, "Bearer ") + } else if strings.HasPrefix(auth, "Basic ") { + encoded := strings.TrimPrefix(auth, "Basic ") + if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil { + if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 { + basicKey = parts[1] // password field is the API key + } + } + } + } + + switch { + case basicKey != "": + return basicKey + case bearerKey != "": + return bearerKey + default: + return r.Header.Get("x-api-key") + } +} + +// CreateCORSMiddleware returns middleware that answers OPTIONS preflight +// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS +// requests pass through untouched. +func CreateCORSMiddleware() chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" { + w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers)) + } else { + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With") + } + w.Header().Set("Access-Control-Max-Age", "86400") + w.WriteHeader(http.StatusNoContent) + }) + } +} + +func isTokenChar(r rune) bool { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case strings.ContainsRune("!#$%&'*+-.^_`|~", r): + default: + return false + } + return true +} + +// sanitizeAccessControlRequestHeaderValues drops any header names that contain +// 
characters outside the HTTP token grammar before echoing them back. +func sanitizeAccessControlRequestHeaderValues(headerValues string) string { + parts := strings.Split(headerValues, ",") + valid := make([]string, 0, len(parts)) + + for _, p := range parts { + v := strings.TrimSpace(p) + if v == "" { + continue + } + + validPart := true + for _, c := range v { + if !isTokenChar(c) { + validPart = false + break + } + } + if validPart { + valid = append(valid, v) + } + } + + return strings.Join(valid, ", ") +} diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go new file mode 100644 index 0000000..e722e4a --- /dev/null +++ b/internal/server/auth_test.go @@ -0,0 +1,120 @@ +package server + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" +) + +func TestServer_ExtractAPIKey(t *testing.T) { + basicHeader := func(user, pass string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) + } + cases := []struct { + name string + auth string + xapi string + want string + }{ + {"none", "", "", ""}, + {"bearer", "Bearer tok123", "", "tok123"}, + {"basic", basicHeader("user", "pw-key"), "", "pw-key"}, + {"x-api-key", "", "xkey", "xkey"}, + {"basic beats bearer", basicHeader("u", "bk"), "", "bk"}, + {"bearer beats x-api-key", "Bearer btok", "xkey", "btok"}, + {"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if c.auth != "" { + r.Header.Set("Authorization", c.auth) + } + if c.xapi != "" { + r.Header.Set("x-api-key", c.xapi) + } + if got := extractAPIKey(r); got != c.want { + t.Errorf("extractAPIKey() = %q, want %q", got, c.want) + } + }) + } +} + +func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"Content-Type, 
Authorization", "Content-Type, Authorization"}, + {" X-Custom , Accept ", "X-Custom, Accept"}, + {"Valid, Bad Header", "Valid"}, + {"Bad@Header", ""}, + {"", ""}, + } + for _, c := range cases { + if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want { + t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestServer_IsTokenChar(t *testing.T) { + for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" { + if !isTokenChar(r) { + t.Errorf("isTokenChar(%q) = false, want true", r) + } + } + for _, r := range " @()/\t\"" { + if isTokenChar(r) { + t.Errorf("isTokenChar(%q) = true, want false", r) + } + } +} + +func TestServer_AuthMiddleware(t *testing.T) { + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" { + t.Error("auth headers leaked to upstream") + } + w.WriteHeader(http.StatusOK) + }) + + t.Run("no keys configured passes through", func(t *testing.T) { + mw := CreateAuthMiddleware(config.Config{}) + w := httptest.NewRecorder() + mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil)) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) + + cfg := config.Config{RequiredAPIKeys: []string{"secret"}} + + t.Run("valid key", func(t *testing.T) { + mw := CreateAuthMiddleware(cfg) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Authorization", "Bearer secret") + w := httptest.NewRecorder() + mw(final).ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) + + t.Run("invalid key", func(t *testing.T) { + mw := CreateAuthMiddleware(cfg) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Authorization", "Bearer wrong") + w := httptest.NewRecorder() + mw(final).ServeHTTP(w, r) + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", w.Code) + } + if w.Header().Get("WWW-Authenticate") == "" { + 
t.Error("missing WWW-Authenticate header") + } + }) +} diff --git a/internal/server/captures.go b/internal/server/captures.go new file mode 100644 index 0000000..fefab9a --- /dev/null +++ b/internal/server/captures.go @@ -0,0 +1,176 @@ +package server + +import ( + "fmt" + "net/http" + "strings" + "sync" + + "github.com/fxamacker/cbor/v2" + "github.com/klauspost/compress/zstd" +) + +// ReqRespCapture is a stored request/response pair for a single metered request. +type ReqRespCapture struct { + ID int `json:"id"` + ReqPath string `json:"req_path"` + ReqHeaders map[string]string `json:"req_headers"` + ReqBody []byte `json:"req_body"` + RespHeaders map[string]string `json:"resp_headers"` + RespBody []byte `json:"resp_body"` +} + +// captureFields is a bitmask controlling what a route stores in a ReqRespCapture. +type captureFields uint + +const ( + captureReqHeaders captureFields = 1 << iota + captureReqBody + captureRespHeaders + captureRespBody +) + +const ( + captureReqAll = captureReqHeaders | captureReqBody + captureRespAll = captureRespHeaders | captureRespBody + captureAll = captureReqAll | captureRespAll +) + +// captureFieldsByPath overrides the default capture mask for routes carrying +// large binary payloads (audio/image) where storing the full body is wasteful. +var captureFieldsByPath = map[string]captureFields{ + "/v1/audio/speech": captureReqAll | captureRespHeaders, + "/v1/audio/voices": captureReqHeaders | captureRespAll, + "/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody, + "/v1/images/generations": captureReqAll | captureRespHeaders, + "/v1/images/edits": captureReqHeaders | captureRespHeaders, + "/sdapi/v1/txt2img": captureReqAll | captureRespHeaders, + "/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders, +} + +// captureFieldsFor returns the capture mask for a request path. Unlisted routes +// (the OpenAI-compatible JSON endpoints) capture everything. 
+func captureFieldsFor(path string) captureFields { + if cf, ok := captureFieldsByPath[path]; ok { + return cf + } + return captureAll +} + +// zstdEncOptions are the shared zstd encoder options for maximum compression. +var zstdEncOptions = []zstd.EOption{ + zstd.WithEncoderLevel(zstd.SpeedBetterCompression), +} + +// zstdEncPool pools zstd.Encoder instances to reduce allocations. +var zstdEncPool = &sync.Pool{ + New: func() interface{} { + enc, _ := zstd.NewWriter(nil, zstdEncOptions...) + return enc + }, +} + +// zstdDecPool pools zstd.Decoder instances to reduce allocations. +var zstdDecPool = &sync.Pool{ + New: func() interface{} { + dec, _ := zstd.NewReader(nil) + return dec + }, +} + +// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd. +// Returns the compressed bytes and the original CBOR byte count for logging. +func compressCapture(c *ReqRespCapture) ([]byte, int, error) { + cborBytes, err := cbor.Marshal(c) + if err != nil { + return nil, 0, fmt.Errorf("marshal capture: %w", err) + } + zenc := zstdEncPool.Get().(*zstd.Encoder) + defer zstdEncPool.Put(zenc) + return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil +} + +// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture. +func decompressCapture(data []byte) (*ReqRespCapture, error) { + dec := zstdDecPool.Get().(*zstd.Decoder) + defer zstdDecPool.Put(dec) + cborBytes, err := dec.DecodeAll(data, nil) + if err != nil { + return nil, fmt.Errorf("decompress capture: %w", err) + } + var capture ReqRespCapture + if err := cbor.Unmarshal(cborBytes, &capture); err != nil { + return nil, fmt.Errorf("unmarshal capture: %w", err) + } + return &capture, nil +} + +// addCapture compresses and stores a capture in the cache. Returns true if the +// capture was stored. 
+func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool { + if !mp.enableCaptures { + return false + } + + compressed, uncompressedBytes, err := compressCapture(&capture) + if err != nil { + mp.logger.Warnf("failed to compress capture: %v, skipping", err) + return false + } + + if err := mp.captureCache.Add(capture.ID, compressed); err != nil { + mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err) + return false + } + + compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100 + mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio) + return true +} + +// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if +// the capture is not found or decompression fails. +func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { + if mp.captureCache == nil { + return nil + } + data, err := mp.captureCache.Get(id) + if err != nil { + return nil + } + capture, err := decompressCapture(data) + if err != nil { + mp.logger.Warnf("failed to decompress capture %d: %v", id, err) + return nil + } + return capture +} + +// sensitiveHeaders lists headers that are redacted in captures. +var sensitiveHeaders = map[string]bool{ + "authorization": true, + "proxy-authorization": true, + "cookie": true, + "set-cookie": true, + "x-api-key": true, +} + +// headerMap flattens an http.Header to a single-value map. +func headerMap(h http.Header) map[string]string { + m := make(map[string]string, len(h)) + for key, values := range h { + if len(values) > 0 { + m[key] = values[0] + } + } + return m +} + +// redactHeaders replaces sensitive header values in-place with "[REDACTED]". 
+func redactHeaders(headers map[string]string) { + for key := range headers { + if sensitiveHeaders[strings.ToLower(key)] { + headers[key] = "[REDACTED]" + } + } +} diff --git a/internal/server/captures_test.go b/internal/server/captures_test.go new file mode 100644 index 0000000..5a89c92 --- /dev/null +++ b/internal/server/captures_test.go @@ -0,0 +1,79 @@ +package server + +import ( + "bytes" + "io" + "testing" + + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +func TestServer_CaptureCompressRoundtrip(t *testing.T) { + orig := &ReqRespCapture{ + ID: 7, + ReqPath: "/v1/chat/completions", + ReqHeaders: map[string]string{"Content-Type": "application/json"}, + ReqBody: []byte(`{"model":"m"}`), + RespHeaders: map[string]string{"Content-Type": "application/json"}, + RespBody: []byte(`{"usage":{}}`), + } + + compressed, uncompressed, err := compressCapture(orig) + if err != nil { + t.Fatalf("compressCapture: %v", err) + } + if uncompressed == 0 || len(compressed) == 0 { + t.Fatalf("unexpected sizes: uncompressed=%d compressed=%d", uncompressed, len(compressed)) + } + + got, err := decompressCapture(compressed) + if err != nil { + t.Fatalf("decompressCapture: %v", err) + } + if got.ID != orig.ID || got.ReqPath != orig.ReqPath || + !bytes.Equal(got.ReqBody, orig.ReqBody) || !bytes.Equal(got.RespBody, orig.RespBody) { + t.Fatalf("roundtrip mismatch: %+v", got) + } +} + +func TestServer_CaptureStoreAndRetrieve(t *testing.T) { + mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5) + if !mm.enableCaptures { + t.Fatal("captures should be enabled with non-zero buffer") + } + + capture := ReqRespCapture{ID: 3, ReqPath: "/v1/chat/completions", ReqBody: []byte("hello")} + if !mm.addCapture(capture) { + t.Fatal("addCapture returned false") + } + + got := mm.getCaptureByID(3) + if got == nil || !bytes.Equal(got.ReqBody, []byte("hello")) { + t.Fatalf("getCaptureByID = %+v", got) + } + if mm.getCaptureByID(999) != nil { + t.Fatal("expected nil for unknown capture 
ID") + } +} + +func TestServer_CaptureDisabled(t *testing.T) { + mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 0) + if mm.enableCaptures { + t.Fatal("captures should be disabled with zero buffer") + } + if mm.addCapture(ReqRespCapture{ID: 1}) { + t.Fatal("addCapture should return false when disabled") + } + if mm.getCaptureByID(1) != nil { + t.Fatal("getCaptureByID should return nil when disabled") + } +} + +func TestServer_CaptureFieldsFor(t *testing.T) { + if got := captureFieldsFor("/v1/chat/completions"); got != captureAll { + t.Fatalf("default = %b, want captureAll", got) + } + if got := captureFieldsFor("/v1/audio/speech"); got != captureReqAll|captureRespHeaders { + t.Fatalf("/v1/audio/speech = %b", got) + } +} diff --git a/internal/server/concurrency.go b/internal/server/concurrency.go new file mode 100644 index 0000000..ea00c3a --- /dev/null +++ b/internal/server/concurrency.go @@ -0,0 +1,55 @@ +package server + +import ( + "net/http" + + "golang.org/x/sync/semaphore" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/router" +) + +// defaultConcurrencyLimit caps simultaneous in-flight requests per model when +// the model config leaves concurrencyLimit unset. Matches the legacy +// proxy.Process default. +const defaultConcurrencyLimit = 10 + +// CreateConcurrencyMiddleware returns middleware that limits simultaneous +// model-dispatched requests per model. Each model gets a semaphore sized to +// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot +// immediately acquire a slot is rejected with 429. Models without a local +// config entry (e.g. peer-routed models) are not limited. 
+func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware { + semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models)) + for id, mc := range cfg.Models { + limit := defaultConcurrencyLimit + if mc.ConcurrencyLimit > 0 { + limit = mc.ConcurrencyLimit + } + semaphores[id] = semaphore.NewWeighted(int64(limit)) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := router.FetchContext(r, cfg) + if err != nil { + router.SendError(w, r, router.ErrNoModelInContext) + return + } + + // fall through for peer models + sem, ok := semaphores[data.ModelID] + if !ok { + next.ServeHTTP(w, r) + return + } + if !sem.TryAcquire(1) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + defer sem.Release(1) + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/server/concurrency_test.go b/internal/server/concurrency_test.go new file mode 100644 index 0000000..c9aa91f --- /dev/null +++ b/internal/server/concurrency_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/router" +) + +func concurrencyTestReq(model string) *http.Request { + r := httptest.NewRequest("GET", "/v1/chat/completions", nil) + return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model})) +} + +func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) { + cfg := config.Config{ + Models: map[string]config.ModelConfig{ + "m1": {ConcurrencyLimit: 1}, + }, + } + + entered := make(chan struct{}) + release := make(chan struct{}) + var once sync.Once + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + once.Do(func() { close(entered) }) + <-release + w.WriteHeader(http.StatusOK) + }) + h := CreateConcurrencyMiddleware(cfg)(final) + + // First request 
occupies the only slot. + done := make(chan struct{}) + go func() { + defer close(done) + h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1")) + }() + <-entered + + // Second concurrent request is rejected with 429. + w := httptest.NewRecorder() + h.ServeHTTP(w, concurrencyTestReq("m1")) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("over-limit status = %d, want 429", w.Code) + } + + // Once the slot frees, a new request succeeds. + close(release) + <-done + w = httptest.NewRecorder() + h.ServeHTTP(w, concurrencyTestReq("m1")) + if w.Code != http.StatusOK { + t.Fatalf("post-release status = %d, want 200", w.Code) + } +} + +func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) { + cfg := config.Config{Models: map[string]config.ModelConfig{}} + + called := 0 + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called++ + w.WriteHeader(http.StatusOK) + }) + h := CreateConcurrencyMiddleware(cfg)(final) + + w := httptest.NewRecorder() + h.ServeHTTP(w, concurrencyTestReq("peer-model")) + if w.Code != http.StatusOK || called != 1 { + t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called) + } +} diff --git a/internal/server/extras_test.go b/internal/server/extras_test.go new file mode 100644 index 0000000..f881ce4 --- /dev/null +++ b/internal/server/extras_test.go @@ -0,0 +1,205 @@ +package server + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +func TestServer_DecompressBody(t *testing.T) { + plain := []byte("hello world") + + var gz bytes.Buffer + gw := gzip.NewWriter(&gz) + gw.Write(plain) + gw.Close() + + var fl bytes.Buffer + fw, _ := flate.NewWriter(&fl, flate.DefaultCompression) + fw.Write(plain) + fw.Close() + + cases := []struct { + name string + body []byte + encoding string + }{ + {"plain", 
plain, ""}, + {"gzip", gz.Bytes(), "gzip"}, + {"deflate", fl.Bytes(), "deflate"}, + {"unknown passthrough", plain, "br"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := decompressBody(c.body, c.encoding) + if err != nil { + t.Fatalf("decompressBody: %v", err) + } + if !bytes.Equal(got, plain) { + t.Errorf("got %q, want %q", got, plain) + } + }) + } +} + +func TestServer_FilterAcceptEncoding(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", ""}, + {"gzip, deflate, br", "gzip, deflate"}, + {"br, zstd", ""}, + {"gzip;q=1.0", "gzip;q=1.0"}, + } + for _, c := range cases { + if got := filterAcceptEncoding(c.in); got != c.want { + t.Errorf("filterAcceptEncoding(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestServer_BodyCopier_Flush(t *testing.T) { + bc := newBodyCopier(httptest.NewRecorder()) + bc.Write([]byte("data")) + bc.Flush() + if bc.Status() != http.StatusOK { + t.Errorf("status = %d, want 200", bc.Status()) + } +} + +func TestServer_HeaderMapAndRedact(t *testing.T) { + h := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer secret"}, + "X-Api-Key": {"key123"}, + } + m := headerMap(h) + if m["Content-Type"] != "application/json" { + t.Errorf("Content-Type = %q", m["Content-Type"]) + } + + redactHeaders(m) + if m["Authorization"] != "[REDACTED]" || m["X-Api-Key"] != "[REDACTED]" { + t.Errorf("sensitive headers not redacted: %v", m) + } + if m["Content-Type"] != "application/json" { + t.Error("non-sensitive header should not be redacted") + } +} + +func TestServer_StripVersionPrefix(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/v/v1/chat", nil) + stripVersionPrefix(r) + if r.URL.Path != "/v1/chat" { + t.Errorf("path = %q, want /v1/chat", r.URL.Path) + } + + r2 := httptest.NewRequest(http.MethodGet, "/v1/chat", nil) + stripVersionPrefix(r2) + if r2.URL.Path != "/v1/chat" { + t.Errorf("path = %q, want unchanged", r2.URL.Path) + } +} + +func 
TestServer_CloseStreams(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.CloseStreams() + select { + case <-s.shutdownCtx.Done(): + default: + t.Error("CloseStreams did not cancel shutdown context") + } + s.CloseStreams() // idempotent +} + +func TestServer_HandleUIAndFavicon(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + for _, path := range []string{"/ui/", "/favicon.ico"} { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil)) + // The embedded ui_dist only carries placeholder.txt in test builds, so + // these resolve to 404 — the handlers still execute end to end. + if w.Code != http.StatusOK && w.Code != http.StatusNotFound { + t.Errorf("%s: status = %d", path, w.Code) + } + } +} + +func TestServer_HandleAPIUnloadAll(t *testing.T) { + local := newStubRouter([]string{"m1"}, "") + s := newTestServer(local, newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d", w.Code) + } + if local.unloadCalls.Load() != 1 { + t.Errorf("unloadCalls = %d, want 1", local.unloadCalls.Load()) + } +} + +func TestServer_HandleAPIUnloadModel(t *testing.T) { + local := newStubRouter([]string{"m1"}, "") + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}} + + t.Run("known model", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/m1", nil)) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) + + t.Run("unknown model 404", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/nope", nil)) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } + }) +} + 
+func TestServer_HandleAPICapture(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.metrics = newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5) + s.metrics.addCapture(ReqRespCapture{ID: 42, ReqPath: "/v1/chat/completions"}) + + t.Run("found", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/42", nil)) + if w.Code != http.StatusOK { + t.Fatalf("status = %d", w.Code) + } + if !bytes.Contains(w.Body.Bytes(), []byte("/v1/chat/completions")) { + t.Errorf("body = %q", w.Body.String()) + } + }) + + t.Run("not found", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/999", nil)) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } + }) + + t.Run("invalid id", func(t *testing.T) { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/abc", nil)) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + }) +} diff --git a/internal/server/filters.go b/internal/server/filters.go new file mode 100644 index 0000000..209e14c --- /dev/null +++ b/internal/server/filters.go @@ -0,0 +1,218 @@ +package server + +import ( + "bytes" + "fmt" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/router" + "github.com/tidwall/sjson" +) + +// CreateFilterMiddleware returns middleware that applies per-model request-body +// filters to JSON requests before they are forwarded upstream: +// +// - UseModelName rewrite (issue #69) +// - StripParams removal (issue #174) +// - SetParams injection (issue #453) +// - SetParamsByID per-alias overrides +// +// Non-JSON requests (GET, multipart forms) pass through untouched. 
The buffered +// body is re-attached with Content-Length / Transfer-Encoding cleanup so the +// downstream reverse proxy forwards the correct bytes (see issue #11). +func CreateFilterMiddleware(cfg config.Config) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Content-Type"), "application/json") { + next.ServeHTTP(w, r) + return + } + + data, err := router.FetchContext(r, cfg) + if err != nil { + router.SendError(w, r, router.ErrNoModelInContext) + return + } + + useModelName, filters, ok := resolveFilters(cfg, data.Model) + if !ok { + next.ServeHTTP(w, r) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + router.SendResponse(w, r, http.StatusBadRequest, "could not read request body") + return + } + + body, err = applyFilters(body, data.Model, useModelName, filters) + if err != nil { + router.SendResponse(w, r, http.StatusInternalServerError, err.Error()) + return + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + r.Header.Del("Transfer-Encoding") + r.Header.Set("Content-Length", strconv.Itoa(len(body))) + r.ContentLength = int64(len(body)) + + next.ServeHTTP(w, r) + }) + } +} + +// CreateFormFilterMiddleware returns middleware that applies the UseModelName +// rewrite (issue #69) to multipart/form-data requests before they are forwarded +// upstream. JSON-body filters (StripParams, SetParams) do not apply to form +// endpoints; only the "model" field is rewritten. +// +// Non-multipart requests pass through untouched. When a rewrite is needed the +// form is reconstructed and re-attached with Content-Type / Content-Length +// cleanup so the downstream reverse proxy forwards the correct bytes. 
+func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") { + next.ServeHTTP(w, r) + return + } + + data, err := router.FetchContext(r, cfg) + if err != nil { + router.SendError(w, r, router.ErrNoModelInContext) + return + } + + useModelName, _, ok := resolveFilters(cfg, data.Model) + if !ok || useModelName == "" { + next.ServeHTTP(w, r) + return + } + + if err := r.ParseMultipartForm(32 << 20); err != nil { + router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) + return + } + + body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName) + if err != nil { + router.SendResponse(w, r, http.StatusInternalServerError, err.Error()) + return + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + r.MultipartForm = nil + r.Header.Del("Transfer-Encoding") + r.Header.Set("Content-Type", contentType) + r.Header.Set("Content-Length", strconv.Itoa(len(body))) + r.ContentLength = int64(len(body)) + + next.ServeHTTP(w, r) + }) + } +} + +// rewriteMultipartModel reconstructs a multipart form, replacing the "model" +// field value with useModelName. It returns the encoded body and the matching +// Content-Type header (which carries the generated boundary). 
+func rewriteMultipartModel(form *multipart.Form, useModelName string) ([]byte, string, error) { + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + + for key, values := range form.Value { + for _, value := range values { + if key == "model" { + value = useModelName + } + field, err := mw.CreateFormField(key) + if err != nil { + return nil, "", fmt.Errorf("error recreating form field %s: %w", key, err) + } + if _, err := field.Write([]byte(value)); err != nil { + return nil, "", fmt.Errorf("error writing form field %s: %w", key, err) + } + } + } + + for key, headers := range form.File { + for _, fh := range headers { + part, err := mw.CreateFormFile(key, fh.Filename) + if err != nil { + return nil, "", fmt.Errorf("error recreating form file %s: %w", key, err) + } + file, err := fh.Open() + if err != nil { + return nil, "", fmt.Errorf("error opening uploaded file %s: %w", key, err) + } + if _, err := io.Copy(part, file); err != nil { + file.Close() + return nil, "", fmt.Errorf("error copying file data %s: %w", key, err) + } + file.Close() + } + } + + if err := mw.Close(); err != nil { + return nil, "", fmt.Errorf("error finalizing multipart form: %w", err) + } + return buf.Bytes(), mw.FormDataContentType(), nil +} + +// resolveFilters returns the filter settings for a requested model. UseModelName +// only applies to local models; peers carry filters but no name rewrite. +func resolveFilters(cfg config.Config, requested string) (useModelName string, filters config.Filters, ok bool) { + if realName, found := cfg.RealModelName(requested); found { + mc := cfg.Models[realName] + return mc.UseModelName, mc.Filters.Filters, true + } + for _, peer := range cfg.Peers { + for _, m := range peer.Models { + if m == requested { + return "", peer.Filters, true + } + } + } + return "", config.Filters{}, false +} + +// applyFilters rewrites the JSON body in place. 
Order matches the legacy +// ProxyManager: useModelName, stripParams, setParams, then setParamsByID (which +// can override setParams). +func applyFilters(body []byte, requested, useModelName string, f config.Filters) ([]byte, error) { + var err error + + if useModelName != "" { + if body, err = sjson.SetBytes(body, "model", useModelName); err != nil { + return nil, fmt.Errorf("error rewriting model name in JSON: %w", err) + } + } + + for _, param := range f.SanitizedStripParams() { + if body, err = sjson.DeleteBytes(body, param); err != nil { + return nil, fmt.Errorf("error stripping parameter %s from request", param) + } + } + + setParams, setKeys := f.SanitizedSetParams() + for _, key := range setKeys { + if body, err = sjson.SetBytes(body, key, setParams[key]); err != nil { + return nil, fmt.Errorf("error setting parameter %s in request", key) + } + } + + byID, byIDKeys := f.SanitizedSetParamsByID(requested) + for _, key := range byIDKeys { + if body, err = sjson.SetBytes(body, key, byID[key]); err != nil { + return nil, fmt.Errorf("error setting parameter %s in request", key) + } + } + + return body, nil +} diff --git a/internal/server/filters_test.go b/internal/server/filters_test.go new file mode 100644 index 0000000..87226bc --- /dev/null +++ b/internal/server/filters_test.go @@ -0,0 +1,132 @@ +package server + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/tidwall/gjson" +) + +func TestServer_ApplyFilters(t *testing.T) { + t.Run("useModelName rewrite", func(t *testing.T) { + out, err := applyFilters([]byte(`{"model":"alias","temp":1}`), "alias", "real-model", config.Filters{}) + if err != nil { + t.Fatalf("applyFilters: %v", err) + } + if got := gjson.GetBytes(out, "model").String(); got != "real-model" { + t.Errorf("model = %q, want real-model", got) + } + }) + + t.Run("strip and set params", func(t *testing.T) { + f := 
config.Filters{ + StripParams: "temperature", + SetParams: map[string]any{"top_p": 0.9}, + } + out, err := applyFilters([]byte(`{"model":"m","temperature":0.7}`), "m", "", f) + if err != nil { + t.Fatalf("applyFilters: %v", err) + } + if gjson.GetBytes(out, "temperature").Exists() { + t.Error("temperature should be stripped") + } + if got := gjson.GetBytes(out, "top_p").Float(); got != 0.9 { + t.Errorf("top_p = %v, want 0.9", got) + } + }) + + t.Run("setParamsByID overrides setParams", func(t *testing.T) { + f := config.Filters{ + SetParams: map[string]any{"top_p": 0.5}, + SetParamsByID: map[string]map[string]any{"alias": {"top_p": 0.1}}, + } + out, err := applyFilters([]byte(`{"model":"alias"}`), "alias", "", f) + if err != nil { + t.Fatalf("applyFilters: %v", err) + } + if got := gjson.GetBytes(out, "top_p").Float(); got != 0.1 { + t.Errorf("top_p = %v, want 0.1", got) + } + }) +} + +func TestServer_RewriteMultipartModel(t *testing.T) { + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + mw.WriteField("model", "old-name") + mw.WriteField("language", "en") + fw, _ := mw.CreateFormFile("file", "audio.wav") + fw.Write([]byte("RIFFdata")) + mw.Close() + + r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf) + r.Header.Set("Content-Type", mw.FormDataContentType()) + if err := r.ParseMultipartForm(32 << 20); err != nil { + t.Fatalf("ParseMultipartForm: %v", err) + } + + body, contentType, err := rewriteMultipartModel(r.MultipartForm, "new-name") + if err != nil { + t.Fatalf("rewriteMultipartModel: %v", err) + } + + parsed, err := multipart.NewReader(bytes.NewReader(body), boundaryOf(t, contentType)).ReadForm(32 << 20) + if err != nil { + t.Fatalf("re-parse: %v", err) + } + if got := parsed.Value["model"][0]; got != "new-name" { + t.Errorf("model = %q, want new-name", got) + } + if got := parsed.Value["language"][0]; got != "en" { + t.Errorf("language = %q, want en", got) + } + fh := parsed.File["file"][0] + f, _ := fh.Open() + data, _ := 
io.ReadAll(f) + f.Close() + if string(data) != "RIFFdata" { + t.Errorf("file data = %q, want RIFFdata", data) + } +} + +func boundaryOf(t *testing.T, contentType string) string { + t.Helper() + _, params, ok := strings.Cut(contentType, "boundary=") + if !ok { + t.Fatalf("no boundary in %q", contentType) + } + return params +} + +func TestServer_FormFilterMiddleware(t *testing.T) { + cfg := config.Config{Models: map[string]config.ModelConfig{ + "whisper": {UseModelName: "whisper-large-v3"}, + }} + + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + mw.WriteField("model", "whisper") + fw, _ := mw.CreateFormFile("file", "a.wav") + fw.Write([]byte("xx")) + mw.Close() + + r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf) + r.Header.Set("Content-Type", mw.FormDataContentType()) + + var gotModel string + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseMultipartForm(32 << 20) + gotModel = r.MultipartForm.Value["model"][0] + }) + CreateFormFilterMiddleware(cfg)(final).ServeHTTP(httptest.NewRecorder(), r) + + if gotModel != "whisper-large-v3" { + t.Errorf("model rewritten to %q, want whisper-large-v3", gotModel) + } +} diff --git a/internal/server/inflight.go b/internal/server/inflight.go new file mode 100644 index 0000000..b5b1d2f --- /dev/null +++ b/internal/server/inflight.go @@ -0,0 +1,33 @@ +package server + +import ( + "net/http" + "sync/atomic" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +// inflightCounter tracks the number of in-flight model-dispatched requests. 
+type inflightCounter struct { + total atomic.Int64 +} + +func (c *inflightCounter) Increment() int64 { return c.total.Add(1) } +func (c *inflightCounter) Decrement() int64 { return c.total.Add(-1) } +func (c *inflightCounter) Current() int64 { return c.total.Load() } + +// CreateInflightMiddleware returns middleware that increments the counter on +// entry and decrements on exit, emitting an InFlightRequestsEvent for each. +func CreateInflightMiddleware(c *inflightCounter) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + event.Emit(shared.InFlightRequestsEvent{Total: int(c.Increment())}) + defer func() { + event.Emit(shared.InFlightRequestsEvent{Total: int(c.Decrement())}) + }() + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/server/log.go b/internal/server/log.go new file mode 100644 index 0000000..2a7d9d5 --- /dev/null +++ b/internal/server/log.go @@ -0,0 +1,222 @@ +package server + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "time" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/router" +) + +// NewLoggers builds the proxy, upstream, and combined (mux) log monitors, +// wiring each one's output per the logToStdout config value. The proxy and +// upstream monitors write into muxlog (rather than os.Stdout directly) so +// muxlog accumulates a combined history for the /logs endpoints, while each +// monitor keeps its own per-source history and event subscribers. 
+// +// Behaviour matches the legacy ProxyManager: +// +// - none: everything discarded +// - both: proxy + upstream both routed to muxlog -> stdout +// - upstream: only upstream routed to muxlog -> stdout; proxy discarded +// - proxy: only proxy routed to muxlog -> stdout; upstream discarded +// +// An empty or unrecognised value behaves like "proxy". +func NewLoggers(logToStdout string) (muxlog, proxylog, upstreamlog *logmon.Monitor) { + switch logToStdout { + case config.LogToStdoutNone: + muxlog = logmon.NewWriter(io.Discard) + proxylog = logmon.NewWriter(io.Discard) + upstreamlog = logmon.NewWriter(io.Discard) + case config.LogToStdoutBoth: + muxlog = logmon.NewWriter(os.Stdout) + proxylog = logmon.NewWriter(muxlog) + upstreamlog = logmon.NewWriter(muxlog) + case config.LogToStdoutUpstream: + muxlog = logmon.NewWriter(os.Stdout) + proxylog = logmon.NewWriter(io.Discard) + upstreamlog = logmon.NewWriter(muxlog) + default: + // config.LogToStdoutProxy, and the fallback for an unset value. + muxlog = logmon.NewWriter(os.Stdout) + proxylog = logmon.NewWriter(muxlog) + upstreamlog = logmon.NewWriter(io.Discard) + } + return muxlog, proxylog, upstreamlog +} + +// handleLogs serves the historical proxy/upstream log. HTML clients are +// redirected to the UI. +func (s *Server) handleLogs(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.Header.Get("Accept"), "text/html") { + http.Redirect(w, r, "/ui/", http.StatusFound) + return + } + w.Header().Set("Content-Type", "text/plain") + w.Write(s.muxlog.GetHistory()) +} + +// getLogger resolves a log monitor by id. An empty id maps to the combined +// muxlog; "proxy" and "upstream" select the respective monitors. 
+func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) { + switch logMonitorID { + case "": + return s.muxlog, nil + case "proxy": + return s.proxylog, nil + case "upstream": + return s.upstreamlog, nil + default: + if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found { + if log, ok := s.local.ProcessLogger(modelID); ok { + return log, nil + } + } + return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID") + } +} + +// handleLogStream tails a log monitor: it writes the history then streams live +// log data until the client disconnects or the server shuts down. +func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("X-Content-Type-Options", "nosniff") + // prevent nginx from buffering streamed logs + w.Header().Set("X-Accel-Buffering", "no") + + logMonitorID := strings.TrimPrefix(r.PathValue("logMonitorID"), "/") + // Strip a query string if it leaked into the path segment. 
+ if idx := strings.Index(logMonitorID, "?"); idx != -1 { + logMonitorID = logMonitorID[:idx] + } + + logger, err := s.getLogger(logMonitorID) + if err != nil { + router.SendResponse(w, r, http.StatusBadRequest, err.Error()) + return + } + + flusher, ok := w.(http.Flusher) + if !ok { + router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") + return + } + + _, skipHistory := r.URL.Query()["no-history"] + if !skipHistory { + if history := logger.GetHistory(); len(history) != 0 { + w.Write(history) + flusher.Flush() + } + } + + sendChan := make(chan []byte, 10) + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + cancelSub := logger.OnLogData(func(data []byte) { + select { + case sendChan <- data: + case <-ctx.Done(): + default: + } + }) + defer cancelSub() + + for { + select { + case <-r.Context().Done(): + return + case <-s.shutdownCtx.Done(): + return + case data := <-sendChan: + w.Write(data) + flusher.Flush() + } + } +} + +// requestLogPathSkips lists path prefixes excluded from the access log because +// they are polled frequently and would drown out useful entries. +var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"} + +// statusRecorder wraps an http.ResponseWriter to capture the response status +// code and the number of body bytes written, so the access log can report +// them. Flush is forwarded so streaming handlers (SSE) still work. +type statusRecorder struct { + http.ResponseWriter + status int + size int +} + +func (sr *statusRecorder) WriteHeader(code int) { + sr.status = code + sr.ResponseWriter.WriteHeader(code) +} + +func (sr *statusRecorder) Write(b []byte) (int, error) { + n, err := sr.ResponseWriter.Write(b) + sr.size += n + return n, err +} + +func (sr *statusRecorder) Flush() { + if f, ok := sr.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// clientIP resolves the originating client address, preferring proxy headers +// over the raw connection address. 
+func clientIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if first, _, found := strings.Cut(xff, ","); found { + return strings.TrimSpace(first) + } + return strings.TrimSpace(xff) + } + if xr := r.Header.Get("X-Real-IP"); xr != "" { + return strings.TrimSpace(xr) + } + if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return host + } + return r.RemoteAddr +} + +// CreateRequestLogMiddleware returns middleware that records one access-log +// line per request to proxylog, in the legacy format: +// +// clientIP "METHOD PATH PROTO" status bodySize "UA" duration +// +// Frequently-polled health/metrics paths are skipped. The path is captured +// before next runs because /upstream rewrites the request URL in place. +func CreateRequestLogMiddleware(proxylog *logmon.Monitor) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, prefix := range requestLogPathSkips { + if strings.HasPrefix(r.URL.Path, prefix) { + next.ServeHTTP(w, r) + return + } + } + + start := time.Now() + ip, method, path, proto, ua := clientIP(r), r.Method, r.URL.Path, r.Proto, r.UserAgent() + + rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rec, r) + + proxylog.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v", + ip, method, path, proto, rec.status, rec.size, ua, time.Since(start)) + }) + } +} diff --git a/internal/server/log_test.go b/internal/server/log_test.go new file mode 100644 index 0000000..1af8cc4 --- /dev/null +++ b/internal/server/log_test.go @@ -0,0 +1,137 @@ +package server + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" +) + +func TestServer_NewLoggers(t *testing.T) { + t.Run("proxy mode routes proxy into muxlog, discards upstream", func(t *testing.T) { + mux, proxy, 
upstream := NewLoggers(config.LogToStdoutProxy) + proxy.Info("PROXYLINE") + upstream.Info("UPSTREAMLINE") + h := string(mux.GetHistory()) + if !strings.Contains(h, "PROXYLINE") { + t.Errorf("muxlog missing proxy line: %q", h) + } + if strings.Contains(h, "UPSTREAMLINE") { + t.Errorf("muxlog should not contain upstream line: %q", h) + } + }) + + t.Run("both mode routes proxy and upstream into muxlog", func(t *testing.T) { + mux, proxy, upstream := NewLoggers(config.LogToStdoutBoth) + proxy.Info("PROXYLINE") + upstream.Info("UPSTREAMLINE") + h := string(mux.GetHistory()) + if !strings.Contains(h, "PROXYLINE") || !strings.Contains(h, "UPSTREAMLINE") { + t.Errorf("muxlog history = %q", h) + } + }) + + t.Run("none mode discards everything from muxlog", func(t *testing.T) { + mux, proxy, upstream := NewLoggers(config.LogToStdoutNone) + proxy.Info("PROXYLINE") + upstream.Info("UPSTREAMLINE") + if len(mux.GetHistory()) != 0 { + t.Errorf("muxlog should be empty, got %q", mux.GetHistory()) + } + }) +} + +func TestServer_HandleLogs_Plain(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + s.muxlog.Write([]byte("a log line")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); ct != "text/plain" { + t.Errorf("Content-Type = %q, want text/plain", ct) + } + if w.Body.String() != "a log line" { + t.Errorf("body = %q", w.Body.String()) + } +} + +func TestServer_HandleLogs_HTMLRedirect(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + req.Header.Set("Accept", "text/html") + w := httptest.NewRecorder() + s.ServeHTTP(w, req) + + if w.Code != http.StatusFound { + t.Fatalf("status = %d, want 302", w.Code) + } + if got := w.Header().Get("Location"); got != "/ui/" { + t.Errorf("Location = %q, want 
/ui/", got) + } +} + +func TestServer_ClientIP(t *testing.T) { + cases := []struct { + name string + setup func(*http.Request) + want string + }{ + {"remote addr", func(r *http.Request) { r.RemoteAddr = "10.0.0.5:1234" }, "10.0.0.5"}, + {"x-forwarded-for", func(r *http.Request) { + r.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8") + }, "1.2.3.4"}, + {"x-real-ip", func(r *http.Request) { r.Header.Set("X-Real-IP", "9.9.9.9") }, "9.9.9.9"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "" + c.setup(r) + if got := clientIP(r); got != c.want { + t.Errorf("clientIP() = %q, want %q", got, c.want) + } + }) + } +} + +func TestServer_RequestLogMiddleware(t *testing.T) { + proxylog := logmon.NewWriter(io.Discard) + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + w.Write([]byte("hello")) + }) + mw := CreateRequestLogMiddleware(proxylog) + + t.Run("logs request", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + r.RemoteAddr = "192.168.1.1:5000" + mw(final).ServeHTTP(httptest.NewRecorder(), r) + + line := string(proxylog.GetHistory()) + for _, want := range []string{"192.168.1.1", "POST /v1/chat/completions", "201", "5"} { + if !strings.Contains(line, want) { + t.Errorf("log line %q missing %q", line, want) + } + } + }) + + for _, path := range []string{"/wol-health", "/api/performance", "/metrics"} { + t.Run("skips "+path, func(t *testing.T) { + skipLog := logmon.NewWriter(io.Discard) + skipMW := CreateRequestLogMiddleware(skipLog) + skipMW(final).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil)) + if len(skipLog.GetHistory()) != 0 { + t.Errorf("%s should not be logged; got %q", path, skipLog.GetHistory()) + } + }) + } +} diff --git a/internal/server/metrics.go b/internal/server/metrics.go new file mode 100644 index 0000000..71ef1f8 --- 
/dev/null +++ b/internal/server/metrics.go @@ -0,0 +1,450 @@ +package server + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/mostlygeek/llama-swap/internal/cache" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/ring" + "github.com/mostlygeek/llama-swap/internal/shared" + "github.com/tidwall/gjson" +) + +// TokenMetrics holds token usage and performance metrics. +type TokenMetrics struct { + CachedTokens int `json:"cache_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + PromptPerSecond float64 `json:"prompt_per_second"` + TokensPerSecond float64 `json:"tokens_per_second"` +} + +// ActivityLogEntry represents parsed token statistics from llama-server logs. +type ActivityLogEntry struct { + ID int `json:"id"` + Timestamp time.Time `json:"timestamp"` + Model string `json:"model"` + ReqPath string `json:"req_path"` + RespContentType string `json:"resp_content_type"` + RespStatusCode int `json:"resp_status_code"` + Tokens TokenMetrics `json:"tokens"` + DurationMs int `json:"duration_ms"` + HasCapture bool `json:"has_capture"` +} + +// ActivityLogEvent carries a single activity log entry to event subscribers. +type ActivityLogEvent struct { + Metrics ActivityLogEntry +} + +func (e ActivityLogEvent) Type() uint32 { + return shared.ActivityLogEventID +} + +// metricsMonitor parses upstream responses for token statistics, keeps a +// bounded in-memory ring of recent activity, and (when captures are enabled) +// stores zstd+CBOR-compressed request/response captures in a sized cache. 
+type metricsMonitor struct { + mu sync.RWMutex + metrics ring.Buffer[ActivityLogEntry] + nextID int + logger *logmon.Monitor + + enableCaptures bool + captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture +} + +// newMetricsMonitor creates a metricsMonitor retaining up to maxMetrics entries. +// captureBufferMB is the capture buffer size in megabytes; 0 disables captures. +func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor { + if maxMetrics <= 0 { + maxMetrics = 1000 + } + mm := &metricsMonitor{ + logger: logger, + metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics), + enableCaptures: captureBufferMB > 0, + } + if captureBufferMB > 0 { + mm.captureCache = cache.New(captureBufferMB * 1024 * 1024) + } + return mm +} + +// queueMetrics adds a metric to the ring and returns its assigned ID. +func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int { + mp.mu.Lock() + defer mp.mu.Unlock() + + metric.ID = mp.nextID + mp.nextID++ + mp.metrics.Push(metric) + return metric.ID +} + +// emitMetric publishes an ActivityLogEvent for the given metric. +func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) { + event.Emit(ActivityLogEvent{Metrics: metric}) +} + +// getMetrics returns a copy of the current metrics. +func (mp *metricsMonitor) getMetrics() []ActivityLogEntry { + mp.mu.RLock() + defer mp.mu.RUnlock() + + result := mp.metrics.Slice() + if result == nil { + return []ActivityLogEntry{} + } + if mp.captureCache != nil { + for i := range result { + result[i].HasCapture = mp.captureCache.Has(result[i].ID) + } + } + return result +} + +// getMetricsJSON returns the current metrics as a JSON array. +func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { + return json.Marshal(mp.getMetrics()) +} + +// record parses a completed response body and stores/emits an activity entry. 
+// When captures are enabled, a zstd+CBOR capture is stored for successful +// requests, with cf controlling which request/response parts are retained. +// reqBody and reqHeaders are the request data buffered before dispatch. +func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) { + tm := ActivityLogEntry{ + Timestamp: time.Now(), + Model: modelID, + ReqPath: r.URL.Path, + RespContentType: recorder.Header().Get("Content-Type"), + RespStatusCode: recorder.Status(), + DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), + } + + queueAndEmit := func() { + tm.ID = mp.queueMetrics(tm) + mp.emitMetric(tm) + } + + if recorder.Status() != http.StatusOK { + mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path) + queueAndEmit() + return + } + + body := recorder.body.Bytes() + if len(body) == 0 { + mp.logger.Warn("metrics: empty body, recording minimal metrics") + queueAndEmit() + return + } + + if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" { + decoded, err := decompressBody(body, encoding) + if err != nil { + mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path) + queueAndEmit() + return + } + body = decoded + } + + if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { + if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { + mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, r.URL.Path) + } else { + tm.Tokens = parsed.Tokens + tm.DurationMs = parsed.DurationMs + } + } else if gjson.ValidBytes(body) { + parsed := gjson.ParseBytes(body) + usage := parsed.Get("usage") + timings := parsed.Get("timings") + + // /infill responses are arrays; timings live in the last element (#463). 
+ if strings.HasPrefix(r.URL.Path, "/infill") { + if arr := parsed.Array(); len(arr) > 0 { + timings = arr[len(arr)-1].Get("timings") + } + } + + if usage.Exists() || timings.Exists() { + if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { + mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, r.URL.Path) + } else { + tm.Tokens = parsedMetrics.Tokens + tm.DurationMs = parsedMetrics.DurationMs + } + } + } else { + mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", r.URL.Path) + } + + tm.ID = mp.queueMetrics(tm) + if mp.enableCaptures { + capture := ReqRespCapture{ + ID: tm.ID, + ReqPath: r.URL.Path, + ReqHeaders: reqHeaders, + } + if cf&captureReqBody != 0 { + capture.ReqBody = reqBody + } + if cf&captureRespHeaders != 0 { + capture.RespHeaders = headerMap(recorder.Header()) + redactHeaders(capture.RespHeaders) + delete(capture.RespHeaders, "Content-Encoding") + } + if cf&captureRespBody != 0 { + capture.RespBody = body + } + if mp.addCapture(capture) { + tm.HasCapture = true + } + } + mp.emitMetric(tm) +} + +// usagePaths lists the JSON paths where a per-event usage object can live. +var usagePaths = []string{"usage", "response.usage", "message.usage"} + +// extractUsageTokens reads input/output/cached token counts from a usage +// gjson.Result, handling the field-name differences across endpoints. 
+func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) { + cached = -1 + if !usage.Exists() { + return + } + + if v := usage.Get("prompt_tokens"); v.Exists() { + input = v.Int() + ok = true + } else if v := usage.Get("input_tokens"); v.Exists() { + input = v.Int() + ok = true + } + + if v := usage.Get("completion_tokens"); v.Exists() { + output = v.Int() + ok = true + } else if v := usage.Get("output_tokens"); v.Exists() { + output = v.Int() + ok = true + } + + if v := usage.Get("cache_read_input_tokens"); v.Exists() { + cached = v.Int() + ok = true + } else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() { + cached = v.Int() + ok = true + } else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() { + cached = v.Int() + ok = true + } + return +} + +func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) { + var ( + inputTokens, outputTokens int64 + cachedTokens int64 = -1 + hasAny bool + timings gjson.Result + ) + + prefix := []byte("data:") + for offset := 0; offset < len(body); { + nl := bytes.IndexByte(body[offset:], '\n') + var line []byte + if nl == -1 { + line = body[offset:] + offset = len(body) + } else { + line = body[offset : offset+nl] + offset += nl + 1 + } + + line = bytes.TrimSpace(line) + if len(line) == 0 || !bytes.HasPrefix(line, prefix) { + continue + } + data := bytes.TrimSpace(line[len(prefix):]) + if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + continue + } + if !gjson.ValidBytes(data) { + continue + } + parsed := gjson.ParseBytes(data) + + for _, path := range usagePaths { + u := parsed.Get(path) + if !u.Exists() { + continue + } + i, o, c, ok := extractUsageTokens(u) + if !ok { + continue + } + hasAny = true + if i > 0 { + inputTokens = i + } + if o > 0 { + outputTokens = o + } + if c >= 0 { + cachedTokens = c + } + } + if t := parsed.Get("timings"); t.Exists() { + timings = t + hasAny = true + } + } + + if !hasAny { + 
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream") + } + + return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil +} + +func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) { + input, output, cached, _ := extractUsageTokens(usage) + return buildMetrics(modelID, start, input, output, cached, timings), nil +} + +// buildMetrics composes an ActivityLogEntry from accumulated token counts and +// optional llama-server timings (which override input/output and provide rates). +func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry { + wallDurationMs := int(time.Since(start).Milliseconds()) + durationMs := wallDurationMs + tokensPerSecond := -1.0 + promptPerSecond := -1.0 + + if timings.Exists() { + inputTokens = timings.Get("prompt_n").Int() + outputTokens = timings.Get("predicted_n").Int() + promptPerSecond = timings.Get("prompt_per_second").Float() + tokensPerSecond = timings.Get("predicted_per_second").Float() + timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float()) + if timingsDurationMs > durationMs { + durationMs = timingsDurationMs + } + if cachedValue := timings.Get("cache_n"); cachedValue.Exists() { + cachedTokens = cachedValue.Int() + } + } + + return ActivityLogEntry{ + Timestamp: time.Now(), + Model: modelID, + Tokens: TokenMetrics{ + CachedTokens: int(cachedTokens), + InputTokens: int(inputTokens), + OutputTokens: int(outputTokens), + PromptPerSecond: promptPerSecond, + TokensPerSecond: tokensPerSecond, + }, + DurationMs: durationMs, + } +} + +// decompressBody decompresses the body based on the Content-Encoding header. 
//
// Supported encodings are "gzip" and "deflate" (matched case-insensitively,
// ignoring surrounding whitespace). "deflate" bodies are accepted both in the
// zlib-wrapped form (RFC 1950) that RFC 9110 specifies and as a bare DEFLATE
// stream (RFC 1951), which some servers send instead; the zlib Adler-32
// trailer is not verified. Any other encoding returns body unchanged with a
// nil error.
func decompressBody(body []byte, encoding string) ([]byte, error) {
	switch strings.ToLower(strings.TrimSpace(encoding)) {
	case "gzip":
		reader, err := gzip.NewReader(bytes.NewReader(body))
		if err != nil {
			return nil, err
		}
		defer reader.Close()
		return io.ReadAll(reader)
	case "deflate":
		if hasZlibHeader(body) {
			// Skip the 2-byte zlib header; the flate reader stops at the end
			// of the DEFLATE stream and never reads the trailing checksum.
			if decoded, err := inflateRaw(body[2:]); err == nil {
				return decoded, nil
			}
			// A bare stream can happen to start with bytes that look like a
			// zlib header, so retry below without stripping anything.
		}
		return inflateRaw(body)
	default:
		return body, nil
	}
}

// hasZlibHeader reports whether b starts with a valid RFC 1950 header for a
// DEFLATE stream without a preset dictionary.
func hasZlibHeader(b []byte) bool {
	if len(b) < 2 {
		return false
	}
	cmf, flg := b[0], b[1]
	switch {
	case cmf&0x0f != 8: // compression method must be DEFLATE
		return false
	case cmf>>4 > 7: // window sizes above 32 KiB are not allowed
		return false
	case flg&0x20 != 0: // preset dictionary: extra bytes follow, unsupported
		return false
	}
	// The two header bytes, read as a big-endian integer, are a multiple of 31.
	return (uint16(cmf)<<8|uint16(flg))%31 == 0
}

// inflateRaw decodes b as a bare DEFLATE (RFC 1951) stream.
func inflateRaw(b []byte) ([]byte, error) {
	reader := flate.NewReader(bytes.NewReader(b))
	defer reader.Close()
	return io.ReadAll(reader)
}

// filterAcceptEncoding filters Accept-Encoding to only gzip/deflate so response
// bodies remain decompressible for metrics parsing. Each kept entry is passed
// through with its parameters (e.g. ";q=0.8") intact, trimmed of surrounding
// whitespace; the coding name is matched case-insensitively and may be
// followed by whitespace before ";". The result is "" when nothing matches.
func filterAcceptEncoding(acceptEncoding string) string {
	var kept []string
	for rest := acceptEncoding; rest != ""; {
		var part string
		part, rest, _ = strings.Cut(rest, ",")
		part = strings.TrimSpace(part)

		coding, _, _ := strings.Cut(part, ";")
		switch strings.ToLower(strings.TrimSpace(coding)) {
		case "gzip", "deflate":
			kept = append(kept, part)
		}
	}
	return strings.Join(kept, ", ")
}

// responseBodyCopier tees the upstream response to the client while buffering
// it for metrics parsing. Status defaults to 200 until WriteHeader is called.
+type responseBodyCopier struct { + http.ResponseWriter + body *bytes.Buffer + tee io.Writer + status int + wroteHeader bool + start time.Time +} + +func newBodyCopier(w http.ResponseWriter) *responseBodyCopier { + buf := &bytes.Buffer{} + return &responseBodyCopier{ + ResponseWriter: w, + body: buf, + tee: io.MultiWriter(w, buf), + status: http.StatusOK, + start: time.Now(), + } +} + +func (w *responseBodyCopier) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.tee.Write(b) +} + +func (w *responseBodyCopier) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true + w.status = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +// Flush forwards to the underlying writer so streaming responses still flush. +func (w *responseBodyCopier) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (w *responseBodyCopier) Status() int { return w.status } +func (w *responseBodyCopier) StartTime() time.Time { return w.start } diff --git a/internal/server/metrics_middleware.go b/internal/server/metrics_middleware.go new file mode 100644 index 0000000..b52a705 --- /dev/null +++ b/internal/server/metrics_middleware.go @@ -0,0 +1,62 @@ +package server + +import ( + "bytes" + "io" + "net/http" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/router" +) + +// CreateMetricsMiddleware returns middleware that records token metrics for +// model-dispatched POST requests. It resolves the model, tees the response into +// a buffer, and parses token usage once the upstream handler returns. 
+func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if mm == nil || r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + // Resolve the model now so downstream dispatch hits the context + // fast path; FetchContext restores the request body. + data, err := router.FetchContext(r, cfg) + if err != nil { + router.SendError(w, r, router.ErrNoModelInContext) + return + } + + // Buffer the request body/headers for capture before dispatch + // consumes them. + cf := captureFieldsFor(r.URL.Path) + var reqBody []byte + var reqHeaders map[string]string + if mm.enableCaptures { + if cf&captureReqBody != 0 && r.Body != nil { + if buffered, err := io.ReadAll(r.Body); err == nil { + reqBody = buffered + r.Body.Close() + r.Body = io.NopCloser(bytes.NewReader(reqBody)) + } + } + if cf&captureReqHeaders != 0 { + reqHeaders = headerMap(r.Header) + redactHeaders(reqHeaders) + } + } + + // Restrict Accept-Encoding to encodings we can decompress so the + // buffered response body stays parseable. 
+ if ae := r.Header.Get("Accept-Encoding"); ae != "" { + r.Header.Set("Accept-Encoding", filterAcceptEncoding(ae)) + } + + recorder := newBodyCopier(w) + next.ServeHTTP(recorder, r) + mm.record(data.ModelID, r, recorder, cf, reqBody, reqHeaders) + }) + } +} diff --git a/internal/server/metrics_test.go b/internal/server/metrics_test.go new file mode 100644 index 0000000..04412cd --- /dev/null +++ b/internal/server/metrics_test.go @@ -0,0 +1,74 @@ +package server + +import ( + "testing" + "time" + + "github.com/tidwall/gjson" +) + +func TestServer_ParseMetrics_ChatCompletions(t *testing.T) { + body := `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}` + parsed := gjson.Parse(body) + entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings")) + if err != nil { + t.Fatalf("parseMetrics: %v", err) + } + if entry.Tokens.InputTokens != 12 || entry.Tokens.OutputTokens != 7 || entry.Tokens.CachedTokens != 4 { + t.Fatalf("tokens = %+v", entry.Tokens) + } +} + +func TestServer_ParseMetrics_Timings(t *testing.T) { + body := `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}` + parsed := gjson.Parse(body) + entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings")) + if err != nil { + t.Fatalf("parseMetrics: %v", err) + } + if entry.Tokens.InputTokens != 20 || entry.Tokens.OutputTokens != 50 || entry.Tokens.CachedTokens != 8 { + t.Fatalf("tokens = %+v", entry.Tokens) + } + if entry.Tokens.TokensPerSecond != 40.0 || entry.Tokens.PromptPerSecond != 100.0 { + t.Fatalf("rates = %+v", entry.Tokens) + } + if entry.DurationMs != 1450 { + t.Fatalf("DurationMs = %d, want 1450", entry.DurationMs) + } +} + +func TestServer_ProcessStreamingResponse(t *testing.T) { + body := []byte("data: {\"choices\":[{}]}\n\n" + + "data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" + + 
"data: [DONE]\n\n") + entry, err := processStreamingResponse("m", time.Now(), body) + if err != nil { + t.Fatalf("processStreamingResponse: %v", err) + } + if entry.Tokens.InputTokens != 15 || entry.Tokens.OutputTokens != 33 { + t.Fatalf("tokens = %+v", entry.Tokens) + } +} + +func TestServer_ProcessStreamingResponse_NoData(t *testing.T) { + if _, err := processStreamingResponse("m", time.Now(), []byte("data: [DONE]\n\n")); err == nil { + t.Fatal("expected error for stream with no usage data") + } +} + +func TestServer_ParseMetrics_Infill(t *testing.T) { + // /infill responses are arrays; timings live in the last element. + body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]` + parsed := gjson.Parse(body) + timings := parsed.Get("timings") + if arr := parsed.Array(); len(arr) > 0 { + timings = arr[len(arr)-1].Get("timings") + } + entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), timings) + if err != nil { + t.Fatalf("parseMetrics: %v", err) + } + if entry.Tokens.InputTokens != 5 || entry.Tokens.OutputTokens != 9 { + t.Fatalf("tokens = %+v", entry.Tokens) + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..5a29a43 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,290 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mostlygeek/llama-swap/internal/chain" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/perf" + "github.com/mostlygeek/llama-swap/internal/router" +) + +// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model +// dispatch. It supersedes router.Server: it builds the local and peer routers +// directly and dispatches between them itself. 
+type Server struct { + cfg config.Config + + muxlog *logmon.Monitor + proxylog *logmon.Monitor + upstreamlog *logmon.Monitor + + perf *perf.Monitor + inflight *inflightCounter + metrics *metricsMonitor + build BuildInfo + + local router.LocalRouter + peer router.Router + + mux *http.ServeMux + handler http.Handler + + shutdownCtx context.Context + shutdownFn context.CancelFunc + shuttingDown atomic.Bool +} + +// modelPostJSONRoutes are endpoints with a model id in the JSON request body. +var modelPostJSONRoutes = []string{ + "/v1/chat/completions", + "/v1/responses", + "/v1/completions", + "/v1/messages", + "/v1/messages/count_tokens", + "/v1/embeddings", + "/reranking", + "/rerank", + "/v1/rerank", + "/v1/reranking", + "/infill", + "/completion", + "/v1/audio/speech", + "/v1/audio/voices", + "/v1/images/generations", + "/sdapi/v1/txt2img", + "/sdapi/v1/img2img", + + // versionless routes, the /v/ is stripped before the request is forwarded upstream + // see issue #728 + "/v/chat/completions", + "/v/responses", + "/v/completions", + "/v/messages", + "/v/messages/count_tokens", + "/v/embeddings", + "/v/rerank", + "/v/reranking", +} + +// modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data +var modelPostFormRoutes = []string{ + "/v1/audio/transcriptions", + "/v1/images/edits", +} + +// modelGetRoutes are model-dispatched GET endpoints (the model arrives as a +// query parameter). +var modelGetRoutes = []string{ + "/v1/audio/voices", + "/sdapi/v1/loras", +} + +// BuildInfo carries version metadata surfaced by GET /api/version. 
+type BuildInfo struct { + Version string + Commit string + Date string +} + +func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) { + var local router.LocalRouter + var err error + + if cfg.Matrix != nil { + local, err = router.NewMatrix(cfg, proxylog, upstreamlog) + if err != nil { + return nil, fmt.Errorf("creating matrix router: %w", err) + } + } else { + local, err = router.NewGroup(cfg, proxylog, upstreamlog) + if err != nil { + return nil, fmt.Errorf("creating group router: %w", err) + } + } + + peer, err := router.NewPeer(cfg, proxylog) + if err != nil { + return nil, fmt.Errorf("creating peer router: %w", err) + } + + shutdownCtx, shutdownFn := context.WithCancel(context.Background()) + s := &Server{ + cfg: cfg, + muxlog: muxlog, + proxylog: proxylog, + upstreamlog: upstreamlog, + perf: perfMon, + inflight: &inflightCounter{}, + metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer), + build: build, + local: local, + peer: peer, + shutdownCtx: shutdownCtx, + shutdownFn: shutdownFn, + } + s.routes() + s.startPreload() + return s, nil +} + +// localPeerHandler dispatches a model-routed request to the local or peer +// router. The model is resolved once via router.FetchContext. +func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) { + stripVersionPrefix(r) + + data, err := router.FetchContext(r, s.cfg) + if err != nil { + router.SendError(w, r, router.ErrNoModelInContext) + return + } + + switch { + case s.local.Handles(data.ModelID): + s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID) + s.local.ServeHTTP(w, r) + case s.peer.Handles(data.ModelID): + s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID) + s.peer.ServeHTTP(w, r) + default: + router.SendError(w, r, router.ErrNoRouterFound) + } +} + +// stripVersionPrefix rewrites versionless /v/... 
requests to their /... form +// before forwarding upstream (issue #728). +func stripVersionPrefix(r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/v/") { + r.URL.Path = strings.TrimPrefix(r.URL.Path, "/v") + } +} + +// routes builds the mux, registers every route, and wraps the mux with the +// global CORS middleware. +func (s *Server) routes() { + authMW := CreateAuthMiddleware(s.cfg) + filterMW := CreateFilterMiddleware(s.cfg) + formFilterMW := CreateFormFilterMiddleware(s.cfg) + + // Model-dispatched routes get auth + per-model concurrency limiting + body + // filters + in-flight tracking + token metrics. concurrencyMW rejects with + // 429 before the body filters do any rewrite work. filterMW rewrites JSON + // bodies and formFilterMW rewrites multipart bodies; each is a no-op for the + // other's Content-Type. Both run before the metrics middleware so it buffers + // the rewritten body. + modelChain := chain.New( + authMW, + CreateConcurrencyMiddleware(s.cfg), + filterMW, + formFilterMW, + CreateInflightMiddleware(s.inflight), + CreateMetricsMiddleware(s.metrics, s.cfg), + ) + // Custom endpoints only need auth. + apiChain := chain.New(authMW) + + mux := http.NewServeMux() + dispatch := http.HandlerFunc(s.localPeerHandler) + + for _, path := range modelPostJSONRoutes { + mux.Handle("POST "+path, modelChain.Then(dispatch)) + } + for _, path := range modelPostFormRoutes { + mux.Handle("POST "+path, modelChain.Then(dispatch)) + } + for _, path := range modelGetRoutes { + mux.Handle("GET "+path, modelChain.Then(dispatch)) + } + + // llama-swap API + custom endpoints. 
+ mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels)) + mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs)) + mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream)) + mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream)) + + mux.HandleFunc("GET /health", handleHealth) + mux.HandleFunc("GET /wol-health", handleHealth) + mux.HandleFunc("GET /{$}", handleRootRedirect) + + // Embedded UI. + mux.HandleFunc("GET /ui/", s.handleUI) + mux.HandleFunc("GET /favicon.ico", s.handleFavicon) + + // Prometheus metrics (no auth, matches the legacy endpoint). + mux.HandleFunc("GET /metrics", s.handleMetrics) + + // Operations endpoints. + mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) + mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning)) + + // Upstream passthrough. + mux.HandleFunc("GET /upstream", handleUpstreamRedirect) + mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream)) + + // API group (API-key protected) consumed by the UI. + mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll)) + mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel)) + mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents)) + mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics)) + mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance)) + mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion)) + mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture)) + + s.mux = mux + s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux) +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) +} + +// CloseStreams cancels long-lived response streams (Server-Sent Events) so a +// graceful httpServer.Shutdown can drain without blocking on them. 
It does not +// tear down routers; call Shutdown for that. Safe to call repeatedly. +func (s *Server) CloseStreams() { + s.shutdownFn() +} + +// Shutdown stops the local and peer routers in parallel. It is idempotent; +// repeated calls return nil without re-running shutdown. +// +// Callers must drain inflight HTTP requests (httpServer.Shutdown) before +// calling this, otherwise inflight requests 502 when their processes are torn +// down. Call CloseStreams before httpServer.Shutdown so SSE streams do not +// block the drain. +func (s *Server) Shutdown(timeout time.Duration) error { + if !s.shuttingDown.CompareAndSwap(false, true) { + return nil + } + s.shutdownFn() + + var wg sync.WaitGroup + var mu sync.Mutex + var errs []error + + for _, rt := range []router.Router{s.local, s.peer} { + if rt == nil { + continue + } + wg.Add(1) + go func(rt router.Router) { + defer wg.Done() + if err := rt.Shutdown(timeout); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + }(rt) + } + + wg.Wait() + return errors.Join(errs...) +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..3f9bbde --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,331 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" + "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" +) + +// stubRouter is a minimal router.LocalRouter for Server dispatch tests. 
+type stubRouter struct { + models map[string]bool + response string + shutdownCalls atomic.Int32 + running map[string]process.ProcessState + unloadCalls atomic.Int32 + loggers map[string]*logmon.Monitor +} + +func newStubRouter(models []string, response string) *stubRouter { + m := make(map[string]bool, len(models)) + for _, id := range models { + m[id] = true + } + return &stubRouter{models: m, response: response} +} + +func (s *stubRouter) Handles(model string) bool { return s.models[model] } +func (s *stubRouter) Shutdown(_ time.Duration) error { s.shutdownCalls.Add(1); return nil } +func (s *stubRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(s.response)) +} + +func (s *stubRouter) RunningModels() map[string]process.ProcessState { return s.running } +func (s *stubRouter) Unload(_ time.Duration, _ ...string) { s.unloadCalls.Add(1) } +func (s *stubRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) { + if s.loggers != nil { + if lg, ok := s.loggers[modelID]; ok { + return lg, true + } + } + return nil, false +} + +// newTestServer wires a Server with stub routers and a built mux. 
+func newTestServer(local router.LocalRouter, peer router.Router) *Server { + ctx, cancel := context.WithCancel(context.Background()) + proxylog := logmon.NewWriter(io.Discard) + s := &Server{ + cfg: config.Config{}, + muxlog: logmon.NewWriter(io.Discard), + proxylog: proxylog, + upstreamlog: logmon.NewWriter(io.Discard), + inflight: &inflightCounter{}, + metrics: newMetricsMonitor(proxylog, 0, 0), + local: local, + peer: peer, + shutdownCtx: ctx, + shutdownFn: cancel, + } + s.routes() + return s +} + +func chatRequest(model string) *http.Request { + body := strings.NewReader(`{"model":"` + model + `"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", body) + req.Header.Set("Content-Type", "application/json") + return req +} + +func TestServer_New_GroupConfig(t *testing.T) { + discard := logmon.NewWriter(io.Discard) + s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{}) + if err != nil { + t.Fatalf("New (group): %v", err) + } + if err := s.Shutdown(time.Second); err != nil { + t.Fatalf("Shutdown: %v", err) + } +} + +func TestServer_New_MatrixConfig(t *testing.T) { + discard := logmon.NewWriter(io.Discard) + cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}} + s, err := New(cfg, discard, discard, discard, nil, BuildInfo{}) + if err != nil { + t.Fatalf("New (matrix): %v", err) + } + if err := s.Shutdown(time.Second); err != nil { + t.Fatalf("Shutdown: %v", err) + } +} + +func TestServer_RouteToLocalModel(t *testing.T) { + s := newTestServer( + newStubRouter([]string{"local-model"}, "local response"), + newStubRouter(nil, ""), + ) + + w := httptest.NewRecorder() + s.ServeHTTP(w, chatRequest("local-model")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if w.Body.String() != "local response" { + t.Errorf("body=%q want %q", w.Body.String(), "local response") + } +} + +func TestServer_RouteToPeerModel(t *testing.T) { + s := 
newTestServer( + newStubRouter(nil, ""), + newStubRouter([]string{"peer-model"}, "peer response"), + ) + + w := httptest.NewRecorder() + s.ServeHTTP(w, chatRequest("peer-model")) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if w.Body.String() != "peer response" { + t.Errorf("body=%q want %q", w.Body.String(), "peer response") + } +} + +func TestServer_UnknownModelReturns404(t *testing.T) { + s := newTestServer( + newStubRouter([]string{"local-model"}, ""), + newStubRouter(nil, ""), + ) + + w := httptest.NewRecorder() + s.ServeHTTP(w, chatRequest("unknown-model")) + + if w.Code != http.StatusNotFound { + t.Errorf("status=%d want 404 body=%q", w.Code, w.Body.String()) + } +} + +func TestServer_UnknownPathReturns404(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/does-not-exist", nil)) + + if w.Code != http.StatusNotFound { + t.Errorf("status=%d want 404", w.Code) + } +} + +func TestServer_Health(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + for _, path := range []string{"/health", "/wol-health"} { + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil)) + if w.Code != http.StatusOK || w.Body.String() != "OK" { + t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String()) + } + } +} + +func TestServer_CORSPreflight(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil) + w := httptest.NewRecorder() + s.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("status=%d want 204", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("Access-Control-Allow-Origin=%q want *", got) + } +} + +func TestServer_Unload(t *testing.T) { + local := 
newStubRouter([]string{"m1"}, "") + s := newTestServer(local, newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/unload", nil)) + + if w.Code != http.StatusOK || w.Body.String() != "OK" { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := local.unloadCalls.Load(); got != 1 { + t.Errorf("unloadCalls=%d want 1", got) + } +} + +func TestServer_Running(t *testing.T) { + local := newStubRouter([]string{"m1"}, "") + local.running = map[string]process.ProcessState{"m1": process.StateReady} + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{Models: map[string]config.ModelConfig{ + "m1": { + Cmd: "llama-server", + Proxy: "http://localhost:9999", + UnloadAfter: 300, + Name: "Model One", + Description: "the first model", + }, + }} + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/running", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + + var resp struct { + Running []runningModel `json:"running"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode: %v body=%q", err, w.Body.String()) + } + if len(resp.Running) != 1 { + t.Fatalf("running=%v want 1 entry", resp.Running) + } + want := runningModel{ + Model: "m1", + State: "ready", + Cmd: "llama-server", + Proxy: "http://localhost:9999", + TTL: 300, + Name: "Model One", + Description: "the first model", + } + if resp.Running[0] != want { + t.Errorf("got %+v want %+v", resp.Running[0], want) + } +} + +func TestServer_Preload(t *testing.T) { + local := newStubRouter([]string{"m1"}, "ok") + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{Hooks: config.HooksConfig{ + OnStartup: config.HookOnStartup{Preload: []string{"m1"}}, + }} + + got := make(chan shared.ModelPreloadedEvent, 1) + cancel := event.On(func(e shared.ModelPreloadedEvent) { got <- e }) + defer cancel() + + 
s.startPreload() + + select { + case e := <-got: + if e.ModelName != "m1" || !e.Success { + t.Errorf("event=%+v want {ModelName:m1 Success:true}", e) + } + case <-time.After(2 * time.Second): + t.Fatal("preload event not received") + } +} + +func TestServer_Shutdown_StopsRoutersAndIsIdempotent(t *testing.T) { + local := newStubRouter([]string{"local-model"}, "") + peer := newStubRouter(nil, "") + s := newTestServer(local, peer) + + if err := s.Shutdown(time.Second); err != nil { + t.Fatalf("Shutdown: %v", err) + } + if err := s.Shutdown(time.Second); err != nil { + t.Fatalf("second Shutdown: %v", err) + } + if got := local.shutdownCalls.Load(); got != 1 { + t.Errorf("local shutdownCalls=%d want 1", got) + } + if got := peer.shutdownCalls.Load(); got != 1 { + t.Errorf("peer shutdownCalls=%d want 1", got) + } +} + +func TestServer_LogStream_ModelID(t *testing.T) { + buf := logmon.NewWriter(io.Discard) + buf.Write([]byte("hello from model")) + + local := newStubRouter([]string{"mymodel"}, "") + local.loggers = map[string]*logmon.Monitor{"mymodel": buf} + + s := newTestServer(local, newStubRouter(nil, "")) + s.cfg = config.Config{Models: map[string]config.ModelConfig{"mymodel": {}}} + + // Pre-cancel the context so the streaming loop exits immediately after + // flushing history. 
+ ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest(http.MethodGet, "/logs/stream/mymodel", nil).WithContext(ctx) + w := httptest.NewRecorder() + s.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if got := w.Body.String(); got != "hello from model" { + t.Errorf("body=%q want %q", got, "hello from model") + } +} + +func TestServer_LogStream_UnknownID_Returns400(t *testing.T) { + s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs/stream/no-such-model", nil)) + + if w.Code != http.StatusBadRequest { + t.Errorf("status=%d want 400", w.Code) + } +} diff --git a/internal/server/ui.go b/internal/server/ui.go new file mode 100644 index 0000000..0d59c71 --- /dev/null +++ b/internal/server/ui.go @@ -0,0 +1,111 @@ +package server + +import ( + "embed" + "io/fs" + "net/http" + "path" + "strings" +) + +// uiStaticFS holds the embedded UI build. The build is copied into ui_dist by +// the Makefile's `ui` target; placeholder.txt keeps the embed valid before a +// build has run. +// +//go:embed ui_dist +var uiStaticFS embed.FS + +// uiFS is the embedded UI rooted at ui_dist. +var uiFS = func() http.FileSystem { + sub, err := fs.Sub(uiStaticFS, "ui_dist") + if err != nil { + panic(err) + } + return http.FS(sub) +}() + +// selectEncoding chooses the best pre-compressed encoding the client accepts. +// It returns the encoding ("br" or "gzip") and the matching file extension. 
+func selectEncoding(acceptEncoding string) (encoding, ext string) { + if acceptEncoding == "" { + return "", "" + } + for _, part := range strings.Split(acceptEncoding, ",") { + if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "br" { + return "br", ".br" + } + } + for _, part := range strings.Split(acceptEncoding, ",") { + if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "gzip" { + return "gzip", ".gz" + } + } + return "", "" +} + +// serveCompressedFile serves name from fsys, preferring a pre-compressed +// sibling (name+".br" / name+".gz") when the client accepts it. It returns an +// error without writing a response when name cannot be served, so callers can +// fall back (e.g. SPA routing). +func serveCompressedFile(fsys http.FileSystem, w http.ResponseWriter, r *http.Request, name string) error { + if encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")); encoding != "" { + if cf, err := fsys.Open(name + ext); err == nil { + defer cf.Close() + if stat, err := cf.Stat(); err == nil && !stat.IsDir() { + w.Header().Set("Content-Encoding", encoding) + w.Header().Add("Vary", "Accept-Encoding") + http.ServeContent(w, r, name, stat.ModTime(), cf) + return nil + } + } + } + + file, err := fsys.Open(name) + if err != nil { + return err + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return err + } + if stat.IsDir() { + return fs.ErrNotExist + } + + http.ServeContent(w, r, name, stat.ModTime(), file) + return nil +} + +// handleUI serves the embedded SPA under /ui/. +func (s *Server) handleUI(w http.ResponseWriter, r *http.Request) { + serveUI(uiFS, w, r) +} + +// serveUI serves the SPA from fsys. Real files are served with compression +// support; unknown paths without a file extension fall back to index.html so +// client-side routing works. 
+func serveUI(fsys http.FileSystem, w http.ResponseWriter, r *http.Request) { + name := strings.TrimPrefix(r.URL.Path, "/ui/") + if name == "" { + name = "index.html" + } + + if err := serveCompressedFile(fsys, w, r, name); err != nil { + if strings.Contains(path.Base(name), ".") { + http.NotFound(w, r) + return + } + if err := serveCompressedFile(fsys, w, r, "index.html"); err != nil { + http.NotFound(w, r) + } + } +} + +// handleFavicon serves /favicon.ico from the embedded UI build. +func (s *Server) handleFavicon(w http.ResponseWriter, r *http.Request) { + if err := serveCompressedFile(uiFS, w, r, "favicon.ico"); err != nil { + http.NotFound(w, r) + } +} diff --git a/internal/server/ui_dist/placeholder.txt b/internal/server/ui_dist/placeholder.txt new file mode 100644 index 0000000..de4afd4 --- /dev/null +++ b/internal/server/ui_dist/placeholder.txt @@ -0,0 +1 @@ +placeholder so //go:embed ui_dist succeeds before the UI is built diff --git a/internal/server/ui_test.go b/internal/server/ui_test.go new file mode 100644 index 0000000..6a9c5d6 --- /dev/null +++ b/internal/server/ui_test.go @@ -0,0 +1,92 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + "testing/fstest" +) + +func TestServer_SelectEncoding(t *testing.T) { + cases := []struct { + accept string + encoding string + ext string + }{ + {"", "", ""}, + {"gzip", "gzip", ".gz"}, + {"gzip, deflate, br", "br", ".br"}, + {"deflate", "", ""}, + {"br;q=1.0, gzip;q=0.8", "br", ".br"}, + } + for _, c := range cases { + enc, ext := selectEncoding(c.accept) + if enc != c.encoding || ext != c.ext { + t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", c.accept, enc, ext, c.encoding, c.ext) + } + } +} + +func uiTestFS() http.FileSystem { + return http.FS(fstest.MapFS{ + "index.html": {Data: []byte("app")}, + "app.js": {Data: []byte("plain")}, + "app.js.br": {Data: []byte("brotli")}, + "app.js.gz": {Data: []byte("gzipped")}, + "favicon.ico": {Data: []byte("icon")}, + }) +} + +func 
serveUIRequest(t *testing.T, path, acceptEncoding string) *httptest.ResponseRecorder { + t.Helper() + req := httptest.NewRequest(http.MethodGet, path, nil) + if acceptEncoding != "" { + req.Header.Set("Accept-Encoding", acceptEncoding) + } + w := httptest.NewRecorder() + serveUI(uiTestFS(), w, req) + return w +} + +func TestServer_ServeUI_File(t *testing.T) { + w := serveUIRequest(t, "/ui/app.js", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if w.Body.String() != "plain" { + t.Errorf("body = %q, want plain", w.Body.String()) + } +} + +func TestServer_ServeUI_Brotli(t *testing.T) { + w := serveUIRequest(t, "/ui/app.js", "gzip, br") + if got := w.Header().Get("Content-Encoding"); got != "br" { + t.Fatalf("Content-Encoding = %q, want br", got) + } + if w.Body.String() != "brotli" { + t.Errorf("body = %q, want brotli", w.Body.String()) + } +} + +func TestServer_ServeUI_IndexAndRoot(t *testing.T) { + for _, path := range []string{"/ui/", "/ui/index.html"} { + w := serveUIRequest(t, path, "") + if w.Code != http.StatusOK || w.Body.String() != "app" { + t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String()) + } + } +} + +func TestServer_ServeUI_SPAFallback(t *testing.T) { + w := serveUIRequest(t, "/ui/models", "") + if w.Code != http.StatusOK || w.Body.String() != "app" { + t.Errorf("SPA fallback: status=%d body=%q", w.Code, w.Body.String()) + } +} + +func TestServer_ServeUI_MissingFile(t *testing.T) { + w := serveUIRequest(t, "/ui/missing.js", "") + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } +} diff --git a/internal/shared/events.go b/internal/shared/events.go new file mode 100644 index 0000000..f006928 --- /dev/null +++ b/internal/shared/events.go @@ -0,0 +1,52 @@ +package shared + +const ProcessStateChangeEventID = 0x01 +const ConfigFileChangedEventID = 0x03 +const ActivityLogEventID = 0x05 +const ModelPreloadedEventID = 0x06 +const InFlightRequestsEventID = 0x07 + +// 
ProcessStateChangeEvent is emitted whenever a process transitions between +// lifecycle states. States are carried as strings so this package stays a leaf +// (no import of internal/process). +type ProcessStateChangeEvent struct { + ProcessName string + OldState string + NewState string +} + +func (e ProcessStateChangeEvent) Type() uint32 { + return ProcessStateChangeEventID +} + +type ReloadingState int + +const ( + ReloadingStateStart ReloadingState = iota + ReloadingStateEnd +) + +type ConfigFileChangedEvent struct { + State ReloadingState +} + +func (e ConfigFileChangedEvent) Type() uint32 { + return ConfigFileChangedEventID +} + +type ModelPreloadedEvent struct { + ModelName string + Success bool +} + +func (e ModelPreloadedEvent) Type() uint32 { + return ModelPreloadedEventID +} + +type InFlightRequestsEvent struct { + Total int +} + +func (e InFlightRequestsEvent) Type() uint32 { + return InFlightRequestsEventID +} diff --git a/proxy/configwatcher/watcher.go b/internal/watcher/watcher.go similarity index 100% rename from proxy/configwatcher/watcher.go rename to internal/watcher/watcher.go diff --git a/proxy/configwatcher/watcher_test.go b/internal/watcher/watcher_test.go similarity index 100% rename from proxy/configwatcher/watcher_test.go rename to internal/watcher/watcher_test.go diff --git a/llama-swap.go b/llama-swap.go index 643db5e..7171436 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -2,8 +2,10 @@ package main import ( "context" + "errors" "flag" "fmt" + "log/slog" "net/http" "os" "os/signal" @@ -13,237 +15,286 @@ import ( "syscall" "time" - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/mostlygeek/llama-swap/proxy" - "github.com/mostlygeek/llama-swap/proxy/config" - 
"github.com/mostlygeek/llama-swap/proxy/configwatcher" + "github.com/mostlygeek/llama-swap/internal/server" + "github.com/mostlygeek/llama-swap/internal/shared" + "github.com/mostlygeek/llama-swap/internal/watcher" ) var ( - version string = "0" - commit string = "abcd1234" - date string = "unknown" + version = "0" + commit = "abcd1234" + date = "unknown" ) +const shutdownTimeout = 30 * time.Second + +// logTimeFormats maps the cfg.LogTimeFormat value to a Go time layout. An +// unset or unrecognised value yields "" — no timestamp prefix. +var logTimeFormats = map[string]string{ + "ansic": time.ANSIC, + "unixdate": time.UnixDate, + "rubydate": time.RubyDate, + "rfc822": time.RFC822, + "rfc822z": time.RFC822Z, + "rfc850": time.RFC850, + "rfc1123": time.RFC1123, + "rfc1123z": time.RFC1123Z, + "rfc3339": time.RFC3339, + "rfc3339nano": time.RFC3339Nano, + "kitchen": time.Kitchen, + "stamp": time.Stamp, + "stampmilli": time.StampMilli, + "stampmicro": time.StampMicro, + "stampnano": time.StampNano, +} + func main() { - // Define a command-line flag for the port - configPath := flag.String("config", "config.yaml", "config file name") - listenStr := flag.String("listen", "", "listen ip/port") - certFile := flag.String("tls-cert-file", "", "TLS certificate file") - keyFile := flag.String("tls-key-file", "", "TLS key file") - showVersion := flag.Bool("version", false, "show version of build") - watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") - mainLogger := logmon.New() + flagConfig := flag.String("config", "", "path to config file (required)") + flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)") + flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file") + flagKeyFile := flag.String("tls-key-file", "", "TLS key file") + flagVersion := flag.Bool("version", false, "show version and exit") + flagWatchConfig := flag.Bool("watch-config", false, "reload config on file change") + 
flag.Parse() - flag.Parse() // Parse the command-line flags - - if *showVersion { - fmt.Printf("version: %s (%s), built at %s", version, commit, date) + if *flagVersion { + fmt.Printf("version: %s (%s), built at %s\n", version, commit, date) os.Exit(0) } - conf, err := config.LoadConfig(*configPath) - if err != nil { - mainLogger.Errorf("Error loading config: %v", err) + if *flagConfig == "" { + slog.Error("-config is required") os.Exit(1) } - if len(conf.Profiles) > 0 { - mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.") - } - - switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) { - case "debug": - mainLogger.SetLogLevel(logmon.LevelDebug) - case "info": - mainLogger.SetLogLevel(logmon.LevelInfo) - case "warn": - mainLogger.SetLogLevel(logmon.LevelWarn) - case "error": - mainLogger.SetLogLevel(logmon.LevelError) - default: - mainLogger.SetLogLevel(logmon.LevelInfo) - } - - mainLogger.Debugf("PID: %d", os.Getpid()) - - if mode := os.Getenv("GIN_MODE"); mode != "" { - gin.SetMode(mode) - } else { - gin.SetMode(gin.ReleaseMode) - } - - // Validate TLS flags. - var useTLS = (*certFile != "" && *keyFile != "") - if (*certFile != "" && *keyFile == "") || - (*certFile == "" && *keyFile != "") { - fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.") + useTLS := *flagCertFile != "" || *flagKeyFile != "" + if (*flagCertFile != "" && *flagKeyFile == "") || (*flagCertFile == "" && *flagKeyFile != "") { + slog.Error("both -tls-cert-file and -tls-key-file must be provided for TLS") os.Exit(1) } - // Set default ports. 
- if *listenStr == "" { - defaultPort := ":8080" + listenAddr := *flagListen + if listenAddr == "" { if useTLS { - defaultPort = ":8443" + listenAddr = ":8443" + } else { + listenAddr = ":8080" } - listenStr = &defaultPort } - var mon *perf.Monitor - if !conf.Performance.Disabled { - mon, err = perf.New(conf.Performance, mainLogger) + configPath := *flagConfig + cfg, err := config.LoadConfig(configPath) + if err != nil { + slog.Error("failed to load config", "path", configPath, "error", err) + os.Exit(1) + } + + // Loggers are wired per cfg.LogToStdout: proxy/upstream feed muxLog, which + // owns the combined history served by /logs. They outlive config reloads, + // so a LogToStdout change requires a restart to take effect. + muxLog, proxyLog, upstreamLog := server.NewLoggers(cfg.LogToStdout) + + if len(cfg.Profiles) > 0 { + proxyLog.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.") + } + + applyLogSettings := func(cfg config.Config) { + level := logmon.LevelInfo + switch strings.ToLower(strings.TrimSpace(cfg.LogLevel)) { + case "debug": + level = logmon.LevelDebug + case "warn": + level = logmon.LevelWarn + case "error": + level = logmon.LevelError + } + timeFormat := logTimeFormats[strings.ToLower(strings.TrimSpace(cfg.LogTimeFormat))] + for _, lg := range []*logmon.Monitor{proxyLog, upstreamLog} { + lg.SetLogLevel(level) + lg.SetLogTimeFormat(timeFormat) + } + } + + applyLogSettings(cfg) + proxyLog.Debugf("PID: %d", os.Getpid()) + + // perfMon outlives config reloads; its config is updated in place. 
+ var perfMon *perf.Monitor + if !cfg.Performance.Disabled { + perfMon, err = perf.New(cfg.Performance, proxyLog) if err != nil { - mainLogger.Errorf("failed to create monitor: %s", err.Error()) + slog.Error("failed to create performance monitor", "error", err) os.Exit(1) } - mon.Start() + perfMon.Start() } else { - mainLogger.Info("performance monitoring is disabled") + proxyLog.Info("performance monitoring is disabled") } - // Setup channels for server management - exitChan := make(chan struct{}) - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + buildInfo := server.BuildInfo{Version: version, Commit: commit, Date: date} - // Context that bounds the lifetime of background watcher goroutines. - watcherCtx, watcherCancel := context.WithCancel(context.Background()) - - // Create server with initial handlergit - srv := &http.Server{ - Addr: *listenStr, + initialSrv, err := server.New(cfg, muxLog, proxyLog, upstreamLog, perfMon, buildInfo) + if err != nil { + slog.Error("failed to create server", "error", err) + os.Exit(1) } - // Support for watching config and reloading when it changes - reloading := false - var reloadMutex sync.Mutex - reloadProxyManager := func() { - reloadMutex.Lock() + // activeSrv is swapped atomically during hot reload. + var activeMu sync.RWMutex + activeSrv := initialSrv + + httpServer := &http.Server{ + Addr: listenAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + activeMu.RLock() + srv := activeSrv + activeMu.RUnlock() + srv.ServeHTTP(w, r) + }), + } + + // reload guards against overlapping reloads triggered by concurrent signals + // or file-watcher callbacks. 
+ var reloading bool + var reloadMu sync.Mutex + + reload := func() { + reloadMu.Lock() if reloading { - reloadMutex.Unlock() + reloadMu.Unlock() return } reloading = true - reloadMutex.Unlock() + reloadMu.Unlock() defer func() { - reloadMutex.Lock() + reloadMu.Lock() reloading = false - reloadMutex.Unlock() + reloadMu.Unlock() }() - if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { - mainLogger.Info("Reloading Configuration") - conf, err = config.LoadConfig(*configPath) - if err != nil { - mainLogger.Warnf("Unable to reload configuration: %v", err) - return - } + proxyLog.Info("reloading configuration") - mainLogger.Debug("Configuration Changed") - currentPM.Shutdown() - if mon != nil { - mon.UpdateConfig(conf.Performance) - } - newPM := proxy.New(conf) - newPM.SetVersion(date, commit, version) - newPM.SetPerfMonitor(mon) - srv.Handler = newPM - mainLogger.Debug("Configuration Reloaded") - - // wait a few seconds and tell any UI to reload - time.AfterFunc(3*time.Second, func() { - event.Emit(proxy.ConfigFileChangedEvent{ - ReloadingState: proxy.ReloadingStateEnd, - }) - }) - } else { - conf, err = config.LoadConfig(*configPath) - if err != nil { - mainLogger.Errorf("Unable to load configuration: %v", err) - os.Exit(1) - } - newPM := proxy.New(conf) - newPM.SetVersion(date, commit, version) - newPM.SetPerfMonitor(mon) - srv.Handler = newPM + newCfg, err := config.LoadConfig(configPath) + if err != nil { + proxyLog.Warnf("failed to reload config: %v", err) + return } + + if len(newCfg.Profiles) > 0 { + proxyLog.Warn("Profile functionality has been removed in favor of Groups. 
See the README for more information.") + } + + if perfMon != nil { + perfMon.UpdateConfig(newCfg.Performance) + } + + newSrv, err := server.New(newCfg, muxLog, proxyLog, upstreamLog, perfMon, buildInfo) + if err != nil { + proxyLog.Warnf("failed to build new server during reload: %v", err) + return + } + + activeMu.Lock() + old := activeSrv + activeSrv = newSrv + activeMu.Unlock() + + applyLogSettings(newCfg) + + if err := old.Shutdown(shutdownTimeout); err != nil { + proxyLog.Warnf("error shutting down old server during reload: %v", err) + } + + // Notify UI after a short delay so it can refresh model state. + time.AfterFunc(3*time.Second, func() { + event.Emit(shared.ConfigFileChangedEvent{State: shared.ReloadingStateEnd}) + }) + + proxyLog.Info("configuration reloaded") } - // load the initial proxy manager - reloadProxyManager() + watcherCtx, watcherCancel := context.WithCancel(context.Background()) + defer watcherCancel() - if *watchConfig { + if *flagWatchConfig { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + slog.Error("watch-config: failed to resolve config path", "error", err) + os.Exit(1) + } + proxyLog.Info("watching configuration for changes (poll-based, 2s interval)") go func() { - absConfigPath, err := filepath.Abs(*configPath) - if err != nil { - mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err) - return - } - mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)") (&configwatcher.Watcher{ Path: absConfigPath, Interval: configwatcher.DefaultInterval, - OnChange: func() { - reloadProxyManager() - }, + OnChange: reload, }).Run(watcherCtx) }() } - // Signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + go func() { + var startErr error + if useTLS { + proxyLog.Infof("llama-swap listening with TLS on https://%s", listenAddr) + startErr = httpServer.ListenAndServeTLS(*flagCertFile, 
*flagKeyFile) + } else { + proxyLog.Infof("llama-swap listening on http://%s", listenAddr) + startErr = httpServer.ListenAndServe() + } + if startErr != nil && !errors.Is(startErr, http.ErrServerClosed) { + slog.Error("http server error", "error", startErr) + os.Exit(1) + } + }() + + exitChan := make(chan struct{}) + go func() { for { sig := <-sigChan switch sig { case syscall.SIGHUP: - mainLogger.Debug("Received SIGHUP") - reloadProxyManager() + proxyLog.Info("received SIGHUP, reloading config") + go reload() case syscall.SIGINT, syscall.SIGTERM: - mainLogger.Debugf("Received signal %v, shutting down...", sig) - if mon != nil { - mon.Stop() - } + proxyLog.Infof("received signal %v, shutting down", sig) watcherCancel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + activeMu.RLock() + srv := activeSrv + activeMu.RUnlock() + + // Close long-lived SSE streams first so httpServer.Shutdown can + // drain without blocking on them for the full timeout. + srv.CloseStreams() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() - - if pm, ok := srv.Handler.(*proxy.ProxyManager); ok { - pm.Shutdown() - } else { - mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager") + if err := httpServer.Shutdown(shutdownCtx); err != nil { + proxyLog.Warnf("http server shutdown error: %v", err) } - if err := srv.Shutdown(ctx); err != nil { - mainLogger.Errorf("Server shutdown: %v", err) + if err := srv.Shutdown(shutdownTimeout); err != nil { + proxyLog.Warnf("router shutdown error: %v", err) } + + if perfMon != nil { + perfMon.Stop() + } + close(exitChan) return - default: - // do nothing on other signals } } }() - // Start server - go func() { - var err error - if useTLS { - mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr) - err = srv.ListenAndServeTLS(*certFile, *keyFile) - } else { - mainLogger.Infof("llama-swap listening on http://%s", *listenStr) - err = 
srv.ListenAndServe() - } - if err != nil && err != http.ErrServerClosed { - mainLogger.Errorf("Fatal server error: %v", err) - os.Exit(1) - } - }() - - // Wait for exit signal <-exitChan + proxyLog.Info("shutdown complete") } diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go index 256f8ce..185cb5b 100644 --- a/proxy/helpers_test.go +++ b/proxy/helpers_test.go @@ -15,8 +15,8 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "gopkg.in/yaml.v3" diff --git a/proxy/matrix.go b/proxy/matrix.go index feb0f12..f699436 100644 --- a/proxy/matrix.go +++ b/proxy/matrix.go @@ -7,8 +7,8 @@ import ( "sort" "sync" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" ) // MatrixSolver contains pure swap-decision logic with no Process dependencies. 
diff --git a/proxy/matrix_test.go b/proxy/matrix_test.go index 8b92137..01c3141 100644 --- a/proxy/matrix_test.go +++ b/proxy/matrix_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index 1a99ef2..5cfa7b7 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -15,10 +15,10 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/cache" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/ring" - "github.com/mostlygeek/llama-swap/proxy/cache" "github.com/tidwall/gjson" ) diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go index a569922..d589b32 100644 --- a/proxy/metrics_monitor_test.go +++ b/proxy/metrics_monitor_test.go @@ -14,8 +14,8 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/event" - "github.com/mostlygeek/llama-swap/proxy/cache" + "github.com/mostlygeek/llama-swap/internal/cache" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) diff --git a/proxy/peerproxy.go b/proxy/peerproxy.go index 5350fa3..98e2ba1 100644 --- a/proxy/peerproxy.go +++ b/proxy/peerproxy.go @@ -10,8 +10,8 @@ import ( "strings" "time" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" ) type peerProxyMember struct { diff --git a/proxy/peerproxy_test.go b/proxy/peerproxy_test.go index dd69471..1837c6e 100644 --- a/proxy/peerproxy_test.go +++ 
b/proxy/peerproxy_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/proxy/process.go b/proxy/process.go index 5c92290..e1117a8 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -17,9 +17,9 @@ import ( "syscall" "time" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" ) type ProcessState string diff --git a/proxy/process_test.go b/proxy/process_test.go index 192f2ec..d6083c8 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" ) diff --git a/proxy/processgroup.go b/proxy/processgroup.go index aaa24d9..4ceb9db 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -6,8 +6,8 @@ import ( "slices" "sync" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/proxy/config" ) type ProcessGroup struct { diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index d261bae..e1284a9 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 0b5b857..a06ce5f 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -16,10 +16,10 @@ 
import ( "time" "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/mostlygeek/llama-swap/proxy/config" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 8c348cd..b3f8437 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -11,7 +11,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/perf" ) diff --git a/proxy/proxymanager_loghandlers_test.go b/proxy/proxymanager_loghandlers_test.go index 1b9ba5b..4e3af50 100644 --- a/proxy/proxymanager_loghandlers_test.go +++ b/proxy/proxymanager_loghandlers_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index b93517a..f637eba 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -16,8 +16,8 @@ import ( "testing" "time" - "github.com/mostlygeek/llama-swap/event" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/event" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" )