From f6cf9f5844ab66bcbde3bdd76b4da32543336602 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Thu, 16 Apr 2026 22:47:42 -0700 Subject: [PATCH] proxy: Refactor tests (#660) - use YAML for test configurations - remove most uses of simple-responder, opting to use process.testHandler Fixes #655 --- ai-plans/improve-tests-655.md | 183 ++++++++++ proxy/helpers_test.go | 202 +++++++++++ proxy/process.go | 51 +++ proxy/proxymanager_test.go | 661 +++++++++++++++++----------------- 4 files changed, 758 insertions(+), 339 deletions(-) create mode 100644 ai-plans/improve-tests-655.md diff --git a/ai-plans/improve-tests-655.md b/ai-plans/improve-tests-655.md new file mode 100644 index 0000000..f48398a --- /dev/null +++ b/ai-plans/improve-tests-655.md @@ -0,0 +1,183 @@ +# Improve Testability (#655) + +## Current Pain Points + +1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested. + +2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted. + +3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide. + +## Stages + +### Stage 1: YAML-based test config helper + +**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs. + +**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None + +Create a test helper in `proxy/helpers_test.go`: + +```go +// testConfigFromYAML substitutes simple-responder paths and loads through +// the real config pipeline (env vars, macros, port assignment, etc.) +func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config { + t.Helper() + yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath)) + cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr)) + require.NoError(t, err) + return cfg +} +``` + +Tests would then look like: + +```go +func TestProxyManager_SwapProcessCorrectly(t *testing.T) { + config := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2 +`) + proxy := New(config) + // ... same assertions +} +``` + +**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline. + +**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`. + +### Stage 2: Injected test handler (eliminate simple-responder for routing tests) + +**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle. + +**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken) + +Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip. + +**2a. Add testHandler to Process:** + +```go +// In Process struct (process.go): +testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy +``` + +In `Process.Start()`, skip subprocess + health check when handler is set: + +```go +func (p *Process) start() error { + if p.testHandler != nil { + p.setState(StateReady) + return nil + } + // existing subprocess logic... +} +``` + +In `Process.ProxyRequest()`, delegate directly to the handler: + +```go +// Before the reverseProxy.ServeHTTP call: +if p.testHandler != nil { + p.testHandler.ServeHTTP(w, r) + return +} +``` + +**2b. Test helper to create the handler:** + +```go +// newTestHandler returns an http.Handler that mimics llama.cpp's API +// (same endpoints as simple-responder). +func newTestHandler(respond string) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... }) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... }) + // ... other endpoints + return mux +} +``` + +Tests for routing/auth/CORS/streaming then become: + +```go +func TestProxyManager_AuthRequired(t *testing.T) { + handler := newTestHandler("model1") + + config := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +requiredAPIKeys: [test-key] +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1 +`) + pm := NewProxyManager(config) + // inject handler — skips subprocess, health check, port allocation + pm.processGroups["model1"].process.testHandler = handler +} +``` + +**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops. + +**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes. + +**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`. + +### Stage 3: Migrate tests incrementally + +**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers. + +**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None + +Priority order: +1. `proxymanager_test.go` routing tests (highest count, most repetition) +2. `peerproxy_test.go` (straightforward, all HTTP routing) +3. `metrics_monitor_test.go` (capture logic doesn't need real processes) +4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests) + +Tests that **must keep simple-responder:** +- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting +- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`) + +**Scope:** ~60-70% of tests can drop simple-responder. + +### Stage 4 (optional): Process interface for ProcessGroup + +**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all. + +**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code) + +```go +type ProcessController interface { + Start() error + Stop(StopStrategy) + ProxyRequest(http.ResponseWriter, *http.Request) error + CurrentState() ProcessState + ID() string + SetState(ProcessState) // for test setup +} +``` + +This requires: +- Extracting the interface +- A `MockProcess` implementation +- Refactoring `ProcessGroup` to use the interface instead of `*Process` + +**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort. + +## Effort/Impact Summary + +| Stage | Effort | Impact | Risk | +|-------|--------|--------|------| +| 1. YAML config helper | Low | Config bugs caught earlier | None | +| 2. Injected test handler | Medium | 10-100x faster routing tests | Low | +| 3. Migrate tests | Medium | Cleaner, more reliable tests | None | +| 4. Process interface | High | Pure unit tests possible | Medium | + +**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need. diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go index c2c4702..ccd4d8f 100644 --- a/proxy/helpers_test.go +++ b/proxy/helpers_test.go @@ -1,15 +1,22 @@ package proxy import ( + "encoding/json" "fmt" + "io" + "net/http" "os" "path/filepath" "runtime" + "strings" "sync" "testing" + "time" "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "gopkg.in/yaml.v3" ) @@ -66,6 +73,16 @@ func getTestPort() int { return port } +// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and +// loads through the real config pipeline (env vars, macros, port assignment, etc.) +func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config { + t.Helper() + yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath)) + cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr)) + require.NoError(t, err) + return cfg +} + func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig { return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) } @@ -88,3 +105,188 @@ proxy: "http://127.0.0.1:%d" return cfg } + +// injectTestHandlers sets a testHandler on every Process in every ProcessGroup +// of the given ProxyManager, bypassing subprocess launches. modelResponses maps +// model IDs to their respond strings; if a model ID is not in the map, the model +// ID itself is used. +func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) { + for _, pg := range pm.processGroups { + for modelID, process := range pg.processes { + respond := modelID + if r, ok := modelResponses[modelID]; ok { + respond = r + } + process.testHandler = newTestHandler(respond) + } + } +} + +// newTestHandler returns an http.Handler that mimics simple-responder's API. +// It supports the endpoints that routing tests depend on, without launching +// any subprocess or binding any port. +func newTestHandler(respond string) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + bodyBytes, _ := io.ReadAll(r.Body) + isStreaming := r.URL.Query().Get("stream") == "true" + + if wait := r.URL.Query().Get("wait"); wait != "" { + if d, err := time.ParseDuration(wait); err == nil { + time.Sleep(d) + } + } + + if isStreaming { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + flusher := w.(http.Flusher) + + for i := 0; i < 10; i++ { + data, _ := json.Marshal(map[string]any{ + "created": time.Now().Unix(), + "choices": []map[string]any{ + {"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil}, + }, + }) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", data) + flusher.Flush() + } + + finalData, _ := json.Marshal(map[string]any{ + "usage": map[string]any{ + "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, + }, + "timings": map[string]any{ + "prompt_n": 25, "prompt_ms": 13, "predicted_n": 10, + "predicted_ms": 17, "predicted_per_second": 10, + }, + }) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData) + flusher.Flush() + + fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n") + flusher.Flush() + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "responseMessage": respond, + "h_content_length": r.Header.Get("Content-Length"), + "request_body": string(bodyBytes), + "usage": map[string]any{ + "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, + }, + "timings": map[string]any{ + "prompt_n": 25, "prompt_ms": 13, "predicted_n": 10, + "predicted_ms": 17, "predicted_per_second": 10, + }, + }) + } + }) + + mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + modelName := gjson.GetBytes(body, "model").String() + if modelName != respond { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)}) + return + } + json.NewEncoder(w).Encode(map[string]string{"message": "ok"}) + }) + + mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "responseMessage": respond, + "usage": map[string]any{ + "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, + }, + }) + }) + + mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "responseMessage": respond, + "usage": map[string]any{ + "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, + }, + }) + }) + + mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(10 << 20); err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)}) + return + } + model := r.FormValue("model") + if model == "" { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"}) + return + } + file, _, err := r.FormFile("file") + if err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)}) + return + } + fileBytes, _ := io.ReadAll(file) + file.Close() + json.NewEncoder(w).Encode(map[string]any{ + "text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)), + "model": model, + "h_content_type": r.Header.Get("Content-Type"), + "h_content_length": r.Header.Get("Content-Length"), + }) + }) + + mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) { + model := r.URL.Query().Get("model") + json.NewEncoder(w).Encode(map[string]any{ + "voices": []string{"voice1"}, "model": model, + }) + }) + + mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprint(w, respond) + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) + }) + + mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + modelName := gjson.GetBytes(body, "model").String() + json.NewEncoder(w).Encode(map[string]any{ + "model": modelName, "images": []string{}, + }) + }) + + mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + modelName := gjson.GetBytes(body, "model").String() + json.NewEncoder(w).Encode(map[string]any{ + "model": modelName, "images": []string{}, + }) + }) + + mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "loras": []string{}, + }) + }) + + return mux +} diff --git a/proxy/process.go b/proxy/process.go index dc05106..1025d18 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -77,6 +77,9 @@ type Process struct { // used for testing to override the default value gracefulStopTimeout time.Duration + // used for testing to bypass subprocess and reverse proxy + testHandler http.Handler + // track the number of failed starts failedStartCount int } @@ -236,6 +239,49 @@ func (p *Process) forceState(newState ProcessState) { // at any time. func (p *Process) start() error { + // test-only fast path: skip subprocess, health check, and TTL goroutine + if p.testHandler != nil { + if curState, err := p.swapState(StateStopped, StateStarting); err != nil { + if err == ErrExpectedStateMismatch { + if curState == StateStarting { + p.waitStarting.Wait() + curState = p.CurrentState() + if curState == StateReady { + return nil + } + return fmt.Errorf("process was already starting but wound up in state %v", curState) + } + return fmt.Errorf("process was in state %v when start() was called", curState) + } + return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err) + } + defer p.waitStarting.Done() + + // Mimic the real stop path: cancelUpstream transitions + // StateStopping -> StateStopped and closes cmdWaitChan, + // matching what waitForCmd does for real subprocesses. + ch := make(chan struct{}) + p.cmdMutex.Lock() + p.cancelUpstream = func() { + if curState := p.CurrentState(); curState == StateStopping { + if _, err := p.swapState(StateStopping, StateStopped); err != nil { + p.forceState(StateStopped) + } + } else { + p.forceState(StateStopped) + } + close(ch) + } + p.cmdWaitChan = ch + p.cmdMutex.Unlock() + + if curState, err := p.swapState(StateStarting, StateReady); err != nil { + return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err) + } + p.failedStartCount = 0 + return nil + } + if p.config.Proxy == "" { return fmt.Errorf("can not start(), upstream proxy missing") } @@ -577,6 +623,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { if !srw.waitForCompletion(completionTimeout) { p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout) } + } + + if p.testHandler != nil { + p.testHandler.ServeHTTP(w, r) + } else if srw != nil { p.reverseProxy.ServeHTTP(srw, r) } else { p.reverseProxy.ServeHTTP(w, r) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index dc59e0d..a5dbfbc 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -45,16 +45,17 @@ func CreateTestResponseRecorder() *TestResponseRecorder { } func TestProxyManager_SwapProcessCorrectly(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) for _, modelName := range []string{"model1", "model2"} { @@ -68,28 +69,28 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { } } func TestProxyManager_SwapMultiProcess(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - LogLevel: "error", - Groups: map[string]config.GroupConfig{ - "G1": { - Swap: true, - Exclusive: false, - Members: []string{"model1"}, - }, - "G2": { - Swap: true, - Exclusive: false, - Members: []string{"model2"}, - }, - }, - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +groups: + G1: + swap: true + exclusive: false + members: + - model1 + G2: + swap: true + exclusive: false + members: + - model2 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) tests := []string{"model1", "model2"} @@ -113,25 +114,24 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { // Test that a persistent group is not affected by the swapping behaviour of // other groups. func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), // goes into the default group - "model2": getTestSimpleResponderConfig("model2"), - }, - LogLevel: "error", - Groups: map[string]config.GroupConfig{ - // the forever group is persistent and should not be affected by model1 - "forever": { - Swap: true, - Exclusive: false, - Persistent: true, - Members: []string{"model2"}, - }, - }, - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +groups: + forever: + swap: true + exclusive: false + persistent: true + members: + - model2 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) // make requests to load all models, loading model1 should not affect model2 @@ -157,17 +157,19 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { t.Skip("skipping slow test") } - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - "model3": getTestSimpleResponderConfig("model3"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 + model3: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) results := map[string]string{} @@ -175,7 +177,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { var wg sync.WaitGroup var mu sync.Mutex - for key := range config.Models { + for key := range cfg.Models { wg.Add(1) go func(key string) { defer wg.Done() @@ -203,7 +205,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { } wg.Wait() - assert.Len(t, results, len(config.Models)) + assert.Len(t, results, len(cfg.Models)) for key, result := range results { assert.Equal(t, key, result) @@ -212,29 +214,27 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) { - model1Config := getTestSimpleResponderConfig("model1") - model1Config.Name = "Model 1" - model1Config.Description = "Model 1 description is used for testing" - - model2Config := getTestSimpleResponderConfig("model2") - model2Config.Name = " " // empty whitespace only strings will get ignored - model2Config.Description = " " - - cfg := config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": model1Config, - "model2": model2Config, - "model3": getTestSimpleResponderConfig("model3"), - }, - Peers: map[string]config.PeerConfig{ - "peer1": { - Proxy: "http://peer1:8080", - Models: []string{"peer-model-a", "peer-model-b"}, - }, - }, - LogLevel: "error", - } + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + name: "Model 1" + description: "Model 1 description is used for testing" + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 + name: " " + description: " " + model3: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3 +peers: + peer1: + proxy: http://peer1:8080 + models: + - peer-model-a + - peer-model-b +`) proxy := New(cfg) @@ -412,22 +412,22 @@ models: func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { // Intentionally add models in non-sorted order and with an unlisted model - config := config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "zeta": getTestSimpleResponderConfig("zeta"), - "alpha": getTestSimpleResponderConfig("alpha"), - "beta": getTestSimpleResponderConfig("beta"), - "hidden": func() config.ModelConfig { - mc := getTestSimpleResponderConfig("hidden") - mc.Unlisted = true - return mc - }(), - }, - LogLevel: "error", - } + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + zeta: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond zeta + alpha: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond alpha + beta: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond beta + hidden: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond hidden + unlisted: true +`) - proxy := New(config) + proxy := New(cfg) // Request models list req := httptest.NewRequest("GET", "/v1/models", nil) @@ -457,21 +457,19 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) { // Configure alias - config := config.Config{ - HealthCheckTimeout: 15, - IncludeAliasesInList: true, - Models: map[string]config.ModelConfig{ - "model1": func() config.ModelConfig { - mc := getTestSimpleResponderConfig("model1") - mc.Name = "Model 1" - mc.Aliases = []string{"alias1"} - return mc - }(), - }, - LogLevel: "error", - } + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +includeAliasesInList: true +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + name: "Model 1" + aliases: + - alias1 +`) - proxy := New(config) + proxy := New(cfg) // Request models list req := httptest.NewRequest("GET", "/v1/models", nil) @@ -534,7 +532,7 @@ func TestProxyManager_Shutdown(t *testing.T) { model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config.Proxy = "http://localhost:10003/" - config := config.AddDefaultGroupToConfig(config.Config{ + cfg := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, Models: map[string]config.ModelConfig{ "model1": model1Config, @@ -550,7 +548,7 @@ func TestProxyManager_Shutdown(t *testing.T) { }, }) - proxy := New(config) + proxy := New(cfg) // Start all the processes var wg sync.WaitGroup @@ -577,13 +575,13 @@ func TestProxyManager_Shutdown(t *testing.T) { } func TestProxyManager_Unload(t *testing.T) { - conf := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + conf := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) proxy := New(conf) reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") @@ -609,22 +607,23 @@ func TestProxyManager_Unload(t *testing.T) { func TestProxyManager_UnloadSingleModel(t *testing.T) { const testGroupId = "testGroup" - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - Groups: map[string]config.GroupConfig{ - testGroupId: { - Swap: false, - Members: []string{"model1", "model2"}, - }, - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +groups: + testGroup: + swap: false + members: + - model1 + - model2 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopImmediately) // start both model @@ -660,14 +659,15 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) { // Test issue #61 `Listing the current list of models and the loaded model.` func TestProxyManager_RunningEndpoint(t *testing.T) { // Shared configuration - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - LogLevel: "warn", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: warn +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + model2: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +`) // Define a helper struct to parse the JSON response. type RunningResponse struct { @@ -683,8 +683,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { } // Create proxy once for all tests - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) t.Run("no models loaded", func(t *testing.T) { req := httptest.NewRequest("GET", "/running", nil) @@ -730,21 +731,22 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { // Verify extended fields are present assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated") assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated") - assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)") + assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to globalTTL (0)") }) } func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + TheExpectedModel: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) // Create a buffer with multipart form data var b bytes.Buffer @@ -785,19 +787,19 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { // Test useModelName in configuration sends overrides what is sent to upstream func TestProxyManager_UseModelName(t *testing.T) { upstreamModelName := "upstreamModel" - modelConfig := getTestSimpleResponderConfig(upstreamModelName) - modelConfig.UseModelName = upstreamModelName - conf := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": modelConfig, - }, - LogLevel: "error", - }) + conf := testConfigFromYAML(t, fmt.Sprintf(` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond %s + useModelName: %s +`, upstreamModelName, upstreamModelName)) proxy := New(conf) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) requestedModel := "model1" @@ -851,16 +853,17 @@ func TestProxyManager_UseModelName(t *testing.T) { } func TestProxyManager_AudioVoicesGETHandler(t *testing.T) { - conf := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + conf := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) proxy := New(conf) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) t.Run("successful GET with model query param", func(t *testing.T) { req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil) @@ -888,13 +891,13 @@ func TestProxyManager_AudioVoicesGETHandler(t *testing.T) { } func TestProxyManager_CORSOptionsHandler(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) tests := []struct { name string @@ -935,8 +938,9 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) for k, v := range tt.requestHeaders { @@ -956,19 +960,17 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { } func TestProxyManager_Upstream(t *testing.T) { - configStr := fmt.Sprintf(` + cfg := testConfigFromYAML(t, ` logLevel: error models: model1: - cmd: %s -port ${PORT} -silent -respond model1 + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 aliases: [model-alias] -`, getSimpleResponderPath()) +`) - config, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) - - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) t.Run("main model name", func(t *testing.T) { req := httptest.NewRequest("GET", "/upstream/model1/test", nil) rec := CreateTestResponseRecorder() @@ -987,16 +989,17 @@ models: } func TestProxyManager_ChatContentLength(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1011,23 +1014,19 @@ func TestProxyManager_ChatContentLength(t *testing.T) { } func TestProxyManager_FiltersStripParams(t *testing.T) { - modelConfig := getTestSimpleResponderConfig("model1") - modelConfig.Filters = config.ModelFilters{ - Filters: config.Filters{ - StripParams: "temperature, model, stream", - }, - } + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + filters: + stripParams: "temperature, model, stream" +`) - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - LogLevel: "error", - Models: map[string]config.ModelConfig{ - "model1": modelConfig, - }, - }) - - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := CreateTestResponseRecorder() @@ -1048,11 +1047,11 @@ func TestProxyManager_FiltersStripParams(t *testing.T) { func TestProxyManager_FiltersSetParamsByID(t *testing.T) { // no explicit aliases — setParamsByID keys are auto-registered as aliases - configStr := strings.Replace(` + cfg := testConfigFromYAML(t, ` logLevel: error models: model1: - cmd: 'SRPATH --port ${PORT} --silent --respond model1' + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 proxy: "http://127.0.0.1:${PORT}" filters: setParams: @@ -1062,15 +1061,11 @@ models: reasoning_effort: high "${MODEL_ID}:low": reasoning_effort: low -`, "SRPATH", simpleResponderPath, -1) - - cfg, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - if !assert.NoError(t, err, "invalid test configuration") { - return - } +`) proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) tests := []struct { requestedModel string @@ -1102,15 +1097,15 @@ models: } func TestProxyManager_HealthEndpoint(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) req := httptest.NewRequest("GET", "/health", nil) rec := CreateTestResponseRecorder() @@ -1121,16 +1116,17 @@ func TestProxyManager_HealthEndpoint(t *testing.T) { // Ensure the custom llama-server /completion endpoint proxies correctly func TestProxyManager_CompletionEndpoint(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody)) @@ -1143,10 +1139,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) { func TestProxyManager_StartupHooks(t *testing.T) { - // using real YAML as the configuration has gotten more complex - // is the right approach as LoadConfigFromReader() does a lot more - // than parse YAML now. Eventually migrate all tests to use this approach - configStr := strings.Replace(` + cfg := testConfigFromYAML(t, ` logLevel: error hooks: on_startup: @@ -1161,16 +1154,10 @@ groups: - model2 models: model1: - cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1 + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 model2: - cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2 -`, "${simpleresponderpath}", simpleResponderPath, -1) - - // Create a test model configuration - config, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - if !assert.NoError(t, err, "Invalid configuration") { - return - } + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 +`) preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events @@ -1181,7 +1168,7 @@ models: defer unsub() // Create the proxy which should trigger preloading - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) for i := 0; i < 2; i++ { @@ -1201,16 +1188,17 @@ models: } func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "author/model": getTestSimpleResponderConfig("author/model"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 + author/model: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond author/model +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) endpoints := []string{ @@ -1252,15 +1240,15 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { } func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "streaming-model": getTestSimpleResponderConfig("streaming-model"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + streaming-model: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond streaming-model +`) - proxy := New(config) + proxy := New(cfg) defer proxy.StopProcesses(StopWaitForInflightRequest) // Make a streaming request @@ -1277,13 +1265,13 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin } func TestProxyManager_ApiGetVersion(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) // Version test map versionTest := map[string]string{ @@ -1292,7 +1280,7 @@ func TestProxyManager_ApiGetVersion(t *testing.T) { "version": "v001", } - proxy := New(config) + proxy := New(cfg) proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"]) defer proxy.StopProcesses(StopWaitForInflightRequest) @@ -1315,17 +1303,20 @@ func TestProxyManager_ApiGetVersion(t *testing.T) { } func TestProxyManager_APIKeyAuth(t *testing.T) { - testConfig := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"}, - LogLevel: "error", - }) + testConfig := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +apiKeys: + - valid-key-1 + - valid-key-2 +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) t.Run("valid key in x-api-key header", func(t *testing.T) { reqBody := `{"model":"model1"}` @@ -1427,16 +1418,17 @@ func TestProxyManager_APIKeyAuth(t *testing.T) { func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) { // Config without RequiredAPIKeys - auth should be disabled - testConfig := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + testConfig := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + model1: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 +`) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) t.Run("requests pass without API key when not configured", func(t *testing.T) { reqBody := `{"model":"model1"}` @@ -1460,8 +1452,7 @@ func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) { })) defer peerServer.Close() - // Create config with peers but no local model for "peer-model" - configStr := fmt.Sprintf(` + testConfig := testConfigFromYAML(t, fmt.Sprintf(` logLevel: error peers: test-peer: @@ -1470,14 +1461,12 @@ peers: - peer-model models: local-model: - cmd: %s -port ${PORT} -silent -respond local-model -`, peerServer.URL, getSimpleResponderPath()) - - testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model +`, peerServer.URL)) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) reqBody := `{"model":"peer-model"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1499,8 +1488,7 @@ models: })) defer peerServer.Close() - // Create config where "shared-model" exists both locally and on peer - configStr := fmt.Sprintf(` + testConfig := testConfigFromYAML(t, fmt.Sprintf(` logLevel: error peers: test-peer: @@ -1509,14 +1497,12 @@ peers: - shared-model models: shared-model: - cmd: %s -port ${PORT} -silent -respond local-response -`, peerServer.URL, getSimpleResponderPath()) - - testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-response +`, peerServer.URL)) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, map[string]string{"shared-model": "local-response"}) reqBody := `{"model":"shared-model"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1535,7 +1521,7 @@ models: })) defer peerServer.Close() - configStr := fmt.Sprintf(` + testConfig := testConfigFromYAML(t, fmt.Sprintf(` logLevel: error peers: test-peer: @@ -1544,14 +1530,12 @@ peers: - peer-model models: local-model: - cmd: %s -port ${PORT} -silent -respond local-model -`, peerServer.URL, getSimpleResponderPath()) - - testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model +`, peerServer.URL)) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) reqBody := `{"model":"unknown-model"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1572,7 +1556,7 @@ models: })) defer peerServer.Close() - configStr := fmt.Sprintf(` + testConfig := testConfigFromYAML(t, fmt.Sprintf(` logLevel: error peers: test-peer: @@ -1582,14 +1566,12 @@ peers: - peer-model models: local-model: - cmd: %s -port ${PORT} -silent -respond local-model -`, peerServer.URL, getSimpleResponderPath()) - - testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model +`, peerServer.URL)) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) reqBody := `{"model":"peer-model"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1601,16 +1583,17 @@ models: }) t.Run("no peers configured - unknown model returns error", func(t *testing.T) { - testConfig := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "local-model": getTestSimpleResponderConfig("local-model"), - }, - LogLevel: "error", - }) + testConfig := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + local-model: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model +`) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) // peerProxy exists but has no peer models configured assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model")) @@ -1632,7 +1615,7 @@ models: })) defer peerServer.Close() - configStr := fmt.Sprintf(` + testConfig := testConfigFromYAML(t, fmt.Sprintf(` logLevel: error peers: test-peer: @@ -1641,14 +1624,12 @@ peers: - peer-model models: local-model: - cmd: %s -port ${PORT} -silent -respond local-model -`, peerServer.URL, getSimpleResponderPath()) - - testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) - assert.NoError(t, err) + cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model +`, peerServer.URL)) proxy := New(testConfig) defer proxy.StopProcesses(StopImmediately) + injectTestHandlers(proxy, nil) reqBody := `{"model":"peer-model"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) @@ -1661,16 +1642,17 @@ models: } func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) { - conf := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "sd-model": getTestSimpleResponderConfig("sd-model"), - }, - LogLevel: "error", - }) + conf := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + sd-model: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model +`) proxy := New(conf) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) t.Run("successful txt2img with model", func(t *testing.T) { reqBody := `{"model":"sd-model","prompt":"a cat"}` @@ -1704,16 +1686,17 @@ func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) { } func TestProxyManager_SdApiGetLoras(t *testing.T) { - conf := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "sd-model": getTestSimpleResponderConfig("sd-model"), - }, - LogLevel: "error", - }) + conf := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +models: + sd-model: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model +`) proxy := New(conf) defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) t.Run("successful GET loras with model query param", func(t *testing.T) { req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil)