mirror of
https://github.com/mostlygeek/llama-swap.git
synced 2026-06-09 06:46:34 +02:00
02e015fa49
This is a huge backend change that essentially started with rewriting the concurrency handling for processes and blew up into a refactor of the entire application. In short, these are the improvements: **Better state and life cycle management:** Life cycle management of processes has always been the trickiest part of the code. Juggling mutex locks across multiple locations to reduce race conditions was complex — too complex for my feeble brain to build a simple mental model around as llama-swap gained more features. All of that has been refactored. Most of the locks are gone, replaced with a single run() that owns all state changes. There is now one place to start from to understand and extend the routing logic. The improved life cycle management makes it easier to implement more complex swap optimization strategies in the future, like #727. **Collation of requests:** llama-swap previously handled requests and swapping in the order they came in. For example, requests for models in the order ABCABC would result in 5 swaps. Now those requests are handled in the order AABBCC. The result is less time spent waiting for swaps under a high-churn request queue. This fixes #588 and #612. A possible future enhancement is to support a starvation parameter so a swap can be forced when models have been waiting too long. **Shared base implementation for groups and swap matrix:** During the refactor it became clear that much of the swapping logic was shared between these two implementations. That is not surprising, considering the swap matrix was added many moons after groups. Now they share a common base, and their specific swap strategies are implemented through the swapPlanner interface. Requests for bespoke or specific swapping scenarios are a common theme in the issues. Now users can implement whatever bespoke and weird swapping strategy they want in their own fork. Just ask your agent of choice to implement swapPlanner. 
I'll still remain more conservative about what actually lands in core llama-swap and will continue to evaluate whether a PR's changes are good for everyone or just for one specific use case. **AI / Agentic Disclosure:** I paid very close attention to the low-level swap concurrency design and implementation. It's important to keep that essential part reliable, boring and free of surprises. Backwards compatibility was also maintained, even the one-way non-exclusive group model loading behaviour that people have rightly pointed out is a weird design decision. With the underlying swap core done, the web server, API and UI sitting on top were largely ported over with Claude Code and Opus 4.7 in multiple phases. If you're curious, I kept the changes in docs/newrouter-todo.md. I did several passes to make sure things weren't left behind. However, even frontier LLMs at the time of this PR still make small decisions that don't make a lot of sense. They get shit wrong all the time, just in small, subtle ways. That said, there are likely to be some new bugs introduced with this massive refactor. I'm fairly confident that there are no major architectural flaws that would cause goal-seeking agents to make dumb, ugly code decisions. For a little while the legacy llama-swap will be available under cmd/legacy/llama-swap. The plan is to eventually delete that entry point as well as the proxy package. On a bit of a personal note, this PR is exciting and a bit sad for me. I hand-wrote much of the original code and this PR ultimately replaces much of it. While the old code served as a good reference for the agent to implement the new stuff, it's still a bit sad to eventually delete it all.
1233 lines
38 KiB
Go
1233 lines
38 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/mostlygeek/llama-swap/internal/config"
|
|
"github.com/mostlygeek/llama-swap/internal/event"
|
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// PROFILE_SPLIT_CHAR is the separator character for profile-qualified names.
// NOTE(review): not referenced in this file; presumably used to split
// "profile:model" style identifiers elsewhere in the package — confirm.
const PROFILE_SPLIT_CHAR = ":"

// proxyCtxKey is the key type for values this package stores in a request's
// context.Context. Using a dedicated named type keeps these keys distinct
// from plain string keys set by other packages.
type proxyCtxKey string
|
|
|
|
// InflightCounter tracks how many requests are currently being handled.
// Every method takes the internal mutex, so a counter may be shared between
// goroutines; each method returns the count observed while holding the lock.
type InflightCounter struct {
	mu    sync.Mutex
	total int
}

// newInflightCounter returns a counter that starts at zero.
func newInflightCounter() *InflightCounter {
	return new(InflightCounter)
}

// Current reports the present count without modifying it.
func (ic *InflightCounter) Current() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	return ic.total
}

// Increment adds one to the count and returns the updated value.
func (ic *InflightCounter) Increment() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	ic.total += 1
	return ic.total
}

// Decrement subtracts one from the count and returns the updated value.
// The count is clamped at zero, so unmatched calls cannot drive it negative.
func (ic *InflightCounter) Decrement() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	if ic.total <= 0 {
		return ic.total
	}
	ic.total -= 1
	return ic.total
}
|
|
|
|
type ProxyManager struct {
|
|
sync.Mutex
|
|
|
|
config config.Config
|
|
ginEngine *gin.Engine
|
|
|
|
// logging
|
|
proxyLogger *logmon.Monitor
|
|
upstreamLogger *logmon.Monitor
|
|
muxLogger *logmon.Monitor
|
|
|
|
metricsMonitor *metricsMonitor
|
|
perfMonitor *perf.Monitor
|
|
|
|
processGroups map[string]*ProcessGroup
|
|
|
|
// matrix-based swap (mutually exclusive with processGroups)
|
|
matrix *Matrix
|
|
|
|
inFlightCounter *InflightCounter
|
|
|
|
// shutdown signaling
|
|
shutdownCtx context.Context
|
|
shutdownCancel context.CancelFunc
|
|
|
|
// version info
|
|
buildDate string
|
|
commit string
|
|
version string
|
|
|
|
// peer proxy see: #296, #433
|
|
peerProxy *PeerProxy
|
|
}
|
|
|
|
func New(proxyConfig config.Config) *ProxyManager {
|
|
// set up loggers
|
|
|
|
var muxLogger, upstreamLogger, proxyLogger *logmon.Monitor
|
|
switch proxyConfig.LogToStdout {
|
|
case config.LogToStdoutNone:
|
|
muxLogger = logmon.NewWriter(io.Discard)
|
|
upstreamLogger = logmon.NewWriter(io.Discard)
|
|
proxyLogger = logmon.NewWriter(io.Discard)
|
|
case config.LogToStdoutBoth:
|
|
muxLogger = logmon.NewWriter(os.Stdout)
|
|
upstreamLogger = logmon.NewWriter(muxLogger)
|
|
proxyLogger = logmon.NewWriter(muxLogger)
|
|
case config.LogToStdoutUpstream:
|
|
muxLogger = logmon.NewWriter(os.Stdout)
|
|
upstreamLogger = logmon.NewWriter(muxLogger)
|
|
proxyLogger = logmon.NewWriter(io.Discard)
|
|
default:
|
|
// same as config.LogToStdoutProxy
|
|
// helpful because some old tests create a config.Config directly and it
|
|
// may not have LogToStdout set explicitly
|
|
muxLogger = logmon.NewWriter(os.Stdout)
|
|
upstreamLogger = logmon.NewWriter(io.Discard)
|
|
proxyLogger = logmon.NewWriter(muxLogger)
|
|
}
|
|
|
|
if proxyConfig.LogRequests {
|
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
|
}
|
|
|
|
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
|
case "debug":
|
|
proxyLogger.SetLogLevel(logmon.LevelDebug)
|
|
upstreamLogger.SetLogLevel(logmon.LevelDebug)
|
|
case "info":
|
|
proxyLogger.SetLogLevel(logmon.LevelInfo)
|
|
upstreamLogger.SetLogLevel(logmon.LevelInfo)
|
|
case "warn":
|
|
proxyLogger.SetLogLevel(logmon.LevelWarn)
|
|
upstreamLogger.SetLogLevel(logmon.LevelWarn)
|
|
case "error":
|
|
proxyLogger.SetLogLevel(logmon.LevelError)
|
|
upstreamLogger.SetLogLevel(logmon.LevelError)
|
|
default:
|
|
proxyLogger.SetLogLevel(logmon.LevelInfo)
|
|
upstreamLogger.SetLogLevel(logmon.LevelInfo)
|
|
}
|
|
|
|
// see: https://go.dev/src/time/format.go
|
|
timeFormats := map[string]string{
|
|
"ansic": time.ANSIC,
|
|
"unixdate": time.UnixDate,
|
|
"rubydate": time.RubyDate,
|
|
"rfc822": time.RFC822,
|
|
"rfc822z": time.RFC822Z,
|
|
"rfc850": time.RFC850,
|
|
"rfc1123": time.RFC1123,
|
|
"rfc1123z": time.RFC1123Z,
|
|
"rfc3339": time.RFC3339,
|
|
"rfc3339nano": time.RFC3339Nano,
|
|
"kitchen": time.Kitchen,
|
|
"stamp": time.Stamp,
|
|
"stampmilli": time.StampMilli,
|
|
"stampmicro": time.StampMicro,
|
|
"stampnano": time.StampNano,
|
|
}
|
|
|
|
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
|
proxyLogger.SetLogTimeFormat(timeFormat)
|
|
upstreamLogger.SetLogTimeFormat(timeFormat)
|
|
}
|
|
|
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
|
|
|
var maxMetrics int
|
|
if proxyConfig.MetricsMaxInMemory <= 0 {
|
|
maxMetrics = 1000 // Default fallback
|
|
} else {
|
|
maxMetrics = proxyConfig.MetricsMaxInMemory
|
|
}
|
|
|
|
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
|
if err != nil {
|
|
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
|
peerProxy = nil
|
|
}
|
|
|
|
pm := &ProxyManager{
|
|
config: proxyConfig,
|
|
ginEngine: gin.New(),
|
|
|
|
proxyLogger: proxyLogger,
|
|
muxLogger: muxLogger,
|
|
upstreamLogger: upstreamLogger,
|
|
|
|
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
|
|
|
processGroups: make(map[string]*ProcessGroup),
|
|
|
|
inFlightCounter: newInflightCounter(),
|
|
|
|
shutdownCtx: shutdownCtx,
|
|
shutdownCancel: shutdownCancel,
|
|
|
|
buildDate: "unknown",
|
|
commit: "abcd1234",
|
|
version: "0",
|
|
|
|
peerProxy: peerProxy,
|
|
}
|
|
|
|
// create either matrix or process groups (mutually exclusive)
|
|
if proxyConfig.Matrix != nil {
|
|
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
|
|
} else {
|
|
for groupID := range proxyConfig.Groups {
|
|
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
|
pm.processGroups[groupID] = processGroup
|
|
}
|
|
}
|
|
|
|
pm.setupGinEngine()
|
|
|
|
// run any startup hooks
|
|
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
|
// do it in the background, don't block startup -- not sure if good idea yet
|
|
go func() {
|
|
discardWriter := &DiscardWriter{}
|
|
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
|
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
|
|
|
if !ok {
|
|
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
|
continue
|
|
}
|
|
|
|
proxyLogger.Infof("Preloading model: %s", modelID)
|
|
|
|
var preloadErr error
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
|
|
if pm.matrix != nil {
|
|
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
preloadErr = err
|
|
} else {
|
|
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
|
|
}
|
|
}
|
|
|
|
if preloadErr != nil {
|
|
event.Emit(ModelPreloadedEvent{
|
|
ModelName: modelID,
|
|
Success: false,
|
|
})
|
|
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
|
|
continue
|
|
} else {
|
|
event.Emit(ModelPreloadedEvent{
|
|
ModelName: modelID,
|
|
Success: true,
|
|
})
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
return pm
|
|
}
|
|
|
|
func (pm *ProxyManager) setupGinEngine() {
|
|
|
|
pm.ginEngine.Use(func(c *gin.Context) {
|
|
|
|
for _, prefix := range []string{
|
|
"/wol-health",
|
|
"/api/performance",
|
|
"/metrics",
|
|
} {
|
|
if strings.HasPrefix(c.Request.URL.Path, prefix) {
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
// capture these because /upstream/:model rewrites them in c.Next()
|
|
clientIP := c.ClientIP()
|
|
method := c.Request.Method
|
|
path := c.Request.URL.Path
|
|
|
|
c.Next()
|
|
|
|
duration := time.Since(start)
|
|
statusCode := c.Writer.Status()
|
|
bodySize := c.Writer.Size()
|
|
|
|
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
|
clientIP,
|
|
method,
|
|
path,
|
|
c.Request.Proto,
|
|
statusCode,
|
|
bodySize,
|
|
c.Request.UserAgent(),
|
|
duration,
|
|
)
|
|
})
|
|
|
|
// see: issue: #81, #77 and #42 for CORS issues
|
|
// respond with permissive OPTIONS for any endpoint
|
|
pm.ginEngine.Use(func(c *gin.Context) {
|
|
if c.Request.Method == "OPTIONS" {
|
|
c.Header("Access-Control-Allow-Origin", "*")
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
|
|
|
// allow whatever the client requested by default
|
|
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
|
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
|
c.Header("Access-Control-Allow-Headers", sanitized)
|
|
} else {
|
|
c.Header(
|
|
"Access-Control-Allow-Headers",
|
|
"Content-Type, Authorization, Accept, X-Requested-With",
|
|
)
|
|
}
|
|
c.Header("Access-Control-Max-Age", "86400")
|
|
c.AbortWithStatus(http.StatusNoContent)
|
|
return
|
|
}
|
|
c.Next()
|
|
})
|
|
|
|
// Set up routes using the Gin engine
|
|
// Protected routes use pm.apiKeyAuth() middleware
|
|
llmHandler := pm.mkProxyJSONHandler(captureAll)
|
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
// Support legacy /v1/completions api, see issue #12
|
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
|
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
// Support anthropic count_tokens API (Also added in the above PR)
|
|
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// Support embeddings and reranking
|
|
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// llama-server's /reranking endpoint + aliases
|
|
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// Unversioned API endpoints, see issue #728
|
|
pm.ginEngine.POST("/v/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
pm.ginEngine.POST("/v/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// llama-server's /infill endpoint for code infilling
|
|
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// llama-server's /completion endpoint
|
|
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
|
|
|
// Support audio/speech endpoint
|
|
pm.ginEngine.POST(
|
|
"/v1/audio/speech",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
|
)
|
|
pm.ginEngine.POST(
|
|
"/v1/audio/voices",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
|
|
)
|
|
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
|
|
|
pm.ginEngine.POST(
|
|
"/v1/audio/transcriptions",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
|
|
)
|
|
pm.ginEngine.POST(
|
|
"/v1/images/generations",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
|
)
|
|
|
|
pm.ginEngine.POST(
|
|
"/v1/images/edits",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
|
|
)
|
|
|
|
// sd.cpp /sdapi/v1 endpoints
|
|
pm.ginEngine.POST("/sdapi/v1/txt2img",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
|
)
|
|
pm.ginEngine.POST("/sdapi/v1/img2img",
|
|
pm.apiKeyAuth(),
|
|
pm.trackInflight(),
|
|
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
|
|
)
|
|
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
|
|
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
|
|
|
// in proxymanager_loghandlers.go
|
|
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
|
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
|
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
|
|
|
/**
|
|
* User Interface Endpoints
|
|
*/
|
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
|
c.Redirect(http.StatusFound, "/ui")
|
|
})
|
|
|
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
|
c.Redirect(http.StatusFound, "/ui/models")
|
|
})
|
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
|
|
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
|
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "OK")
|
|
})
|
|
|
|
pm.ginEngine.GET("/metrics", pm.prometheusMetricsHandler)
|
|
|
|
// see cmd/wol-proxy/wol-proxy.go, not logged
|
|
pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "OK")
|
|
})
|
|
|
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
|
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
|
c.Data(http.StatusOK, "image/x-icon", data)
|
|
} else {
|
|
c.String(http.StatusInternalServerError, err.Error())
|
|
}
|
|
})
|
|
|
|
reactFS, err := GetReactFS()
|
|
if err != nil {
|
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
|
} else {
|
|
// Serve files with compression support under /ui/*
|
|
// This handler checks for pre-compressed .br and .gz files
|
|
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
|
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
|
// Default to index.html for directory-like paths
|
|
if filepath == "" {
|
|
filepath = "index.html"
|
|
}
|
|
|
|
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
|
})
|
|
|
|
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Check if this looks like a file request (has extension)
|
|
path := c.Request.URL.Path
|
|
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
|
// This was likely a file request that wasn't found
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Serve index.html for SPA routing
|
|
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
|
})
|
|
}
|
|
|
|
// see: proxymanager_api.go
|
|
// add API handler functions
|
|
addApiHandlers(pm)
|
|
|
|
// Disable console color for testing
|
|
gin.DisableConsoleColor()
|
|
}
|
|
|
|
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
|
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler interface
|
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
pm.ginEngine.ServeHTTP(w, r)
|
|
}
|
|
|
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
|
// This is the public method safe for concurrent calls.
|
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
|
// a complete shutdown, allowing for process replacement without full termination.
|
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
if pm.matrix != nil {
|
|
pm.matrix.StopProcesses(strategy)
|
|
return
|
|
}
|
|
|
|
// stop Processes in parallel
|
|
var wg sync.WaitGroup
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.StopProcesses(strategy)
|
|
}(processGroup)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// Shutdown stops all processes managed by this ProxyManager
|
|
func (pm *ProxyManager) Shutdown() {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
|
|
|
if pm.matrix != nil {
|
|
pm.matrix.Shutdown()
|
|
pm.shutdownCancel()
|
|
return
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
// Send shutdown signal to all process in groups
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.Shutdown()
|
|
}(processGroup)
|
|
}
|
|
wg.Wait()
|
|
pm.shutdownCancel()
|
|
}
|
|
|
|
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
|
processGroup := pm.findGroupByModelName(realModelName)
|
|
if processGroup == nil {
|
|
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
|
}
|
|
|
|
if processGroup.exclusive {
|
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
|
for groupId, otherGroup := range pm.processGroups {
|
|
if groupId != processGroup.id && !otherGroup.persistent {
|
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
|
}
|
|
}
|
|
}
|
|
|
|
return processGroup, nil
|
|
}
|
|
|
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|
data := make([]gin.H, 0, len(pm.config.Models))
|
|
createdTime := time.Now().Unix()
|
|
|
|
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
|
record := gin.H{
|
|
"id": modelId,
|
|
"object": "model",
|
|
"created": createdTime,
|
|
"owned_by": "llama-swap",
|
|
}
|
|
|
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
|
record["name"] = name
|
|
}
|
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
|
record["description"] = desc
|
|
}
|
|
|
|
// Add metadata if present
|
|
if len(modelConfig.Metadata) > 0 {
|
|
record["meta"] = gin.H{
|
|
"llamaswap": modelConfig.Metadata,
|
|
}
|
|
}
|
|
return record
|
|
}
|
|
|
|
for id, modelConfig := range pm.config.Models {
|
|
if modelConfig.Unlisted {
|
|
continue
|
|
}
|
|
|
|
data = append(data, newRecord(id, modelConfig))
|
|
|
|
// Include aliases
|
|
if pm.config.IncludeAliasesInList {
|
|
for _, alias := range modelConfig.Aliases {
|
|
if alias := strings.TrimSpace(alias); alias != "" {
|
|
data = append(data, newRecord(alias, modelConfig))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if pm.peerProxy != nil {
|
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
|
// add peer models
|
|
for _, modelID := range peer.Models {
|
|
// Skip unlisted models if not showing them
|
|
record := newRecord(modelID, config.ModelConfig{
|
|
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
|
Metadata: map[string]any{
|
|
"peerID": peerID,
|
|
},
|
|
})
|
|
|
|
data = append(data, record)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sort by the "id" key
|
|
sort.Slice(data, func(i, j int) bool {
|
|
si, _ := data[i]["id"].(string)
|
|
sj, _ := data[j]["id"].(string)
|
|
return si < sj
|
|
})
|
|
|
|
// Set CORS headers if origin exists
|
|
if origin := c.GetHeader("Origin"); origin != "" {
|
|
c.Header("Access-Control-Allow-Origin", origin)
|
|
}
|
|
|
|
// Use gin's JSON method which handles content-type and encoding
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
// findModelInPath searches for a valid model name in a path with slashes.
|
|
// It iteratively builds up path segments until it finds a matching model.
|
|
// Returns: (searchModelName, realModelName, remainingPath, found)
|
|
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
|
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
|
searchModelName := ""
|
|
|
|
for i, part := range parts {
|
|
if part == "" {
|
|
continue
|
|
}
|
|
|
|
if searchModelName == "" {
|
|
searchModelName = part
|
|
} else {
|
|
searchModelName = searchModelName + "/" + part
|
|
}
|
|
|
|
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
|
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
|
}
|
|
}
|
|
|
|
return "", "", "", false
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|
upstreamPath := c.Param("upstreamPath")
|
|
|
|
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
|
|
|
if !modelFound {
|
|
pm.sendErrorResponse(c, http.StatusNotFound, "model not found")
|
|
return
|
|
}
|
|
|
|
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
|
// This ensures relative URLs in upstream responses resolve correctly and
|
|
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
|
// HTTP method (301 would downgrade to GET).
|
|
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
|
newPath := "/upstream/" + searchModelName + "/"
|
|
if c.Request.URL.RawQuery != "" {
|
|
newPath += "?" + c.Request.URL.RawQuery
|
|
}
|
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
|
c.Redirect(http.StatusMovedPermanently, newPath)
|
|
} else {
|
|
c.Redirect(http.StatusPermanentRedirect, newPath)
|
|
}
|
|
return
|
|
}
|
|
|
|
var handler func(string, http.ResponseWriter, *http.Request) error
|
|
if pm.matrix != nil {
|
|
handler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
handler = processGroup.ProxyRequest
|
|
}
|
|
|
|
// rewrite the path
|
|
originalPath := c.Request.URL.Path
|
|
c.Request.URL.Path = remainingPath
|
|
|
|
// attempt to record metrics if it is a POST request
|
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
|
return
|
|
}
|
|
} else {
|
|
if err := handler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
|
|
return func(c *gin.Context) {
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
return
|
|
}
|
|
|
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
return
|
|
}
|
|
|
|
// Look for a matching local model first
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
|
|
modelID, found := pm.config.RealModelName(requestedModel)
|
|
if found {
|
|
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
|
if pm.matrix != nil {
|
|
localHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
localHandler = processGroup.ProxyRequest
|
|
}
|
|
|
|
// issue #69 allow custom model names to be sent to upstream
|
|
useModelName := pm.config.Models[modelID].UseModelName
|
|
if useModelName != "" {
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
return
|
|
}
|
|
}
|
|
|
|
// issue #174 strip parameters from the JSON body
|
|
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
|
if err != nil { // just log it and continue
|
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
|
} else {
|
|
for _, param := range stripParams {
|
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// issue #453 set/override parameters in the JSON body
|
|
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
|
for _, key := range setParamKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
|
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
|
for _, key := range setParamsByIDKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
nextHandler = localHandler
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
modelID = requestedModel
|
|
|
|
// issue #453 apply filters for peer requests
|
|
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
|
|
|
// Apply stripParams - remove specified parameters from request
|
|
stripParams := peerFilters.SanitizedStripParams()
|
|
for _, param := range stripParams {
|
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
|
return
|
|
}
|
|
}
|
|
|
|
// Apply setParams - set/override specified parameters in request
|
|
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
|
for _, key := range setParamKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
// dechunk it as we already have all the body bytes see issue #11
|
|
c.Request.Header.Del("transfer-encoding")
|
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
|
c.Request.ContentLength = int64(len(bodyBytes))
|
|
|
|
// issue #728 support versionless API requests
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v/") {
|
|
c.Request.URL.Path = strings.TrimPrefix(c.Request.URL.Path, "/v")
|
|
}
|
|
|
|
// issue #366 extract values that downstream handlers may need
|
|
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
|
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
|
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
|
c.Request = c.Request.WithContext(ctx)
|
|
|
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
|
return
|
|
}
|
|
} else {
|
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// mkPostFormHandler creates a POST form handler for inference backends
|
|
// with a custom captureFields to filter out large binary requests or responses.
|
|
func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
|
|
return func(c *gin.Context) {
|
|
// Parse multipart form
|
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// Get model parameter from the form
|
|
requestedModel := c.Request.FormValue("model")
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
|
return
|
|
}
|
|
|
|
// Look for a matching local model first, then check peers
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
var useModelName string
|
|
|
|
modelID, found := pm.config.RealModelName(requestedModel)
|
|
if found {
|
|
if pm.matrix != nil {
|
|
nextHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
nextHandler = processGroup.ProxyRequest
|
|
}
|
|
|
|
useModelName = pm.config.Models[modelID].UseModelName
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
modelID = requestedModel
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
// Create a new buffer for the reconstructed request
|
|
var requestBuffer bytes.Buffer
|
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
|
|
// Copy all form values
|
|
for key, values := range c.Request.MultipartForm.Value {
|
|
for _, value := range values {
|
|
fieldValue := value
|
|
// If this is the model field and we have a profile, use just the model name
|
|
if key == "model" {
|
|
// # issue #69 allow custom model names to be sent to upstream
|
|
if useModelName != "" {
|
|
fieldValue = useModelName
|
|
} else {
|
|
fieldValue = requestedModel
|
|
}
|
|
}
|
|
field, err := multipartWriter.CreateFormField(key)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
|
return
|
|
}
|
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Copy all files from the original request
|
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
|
for _, fileHeader := range fileHeaders {
|
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
|
return
|
|
}
|
|
|
|
file, err := fileHeader.Open()
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
|
return
|
|
}
|
|
|
|
if _, err = io.Copy(formFile, file); err != nil {
|
|
file.Close()
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
|
return
|
|
}
|
|
file.Close()
|
|
}
|
|
}
|
|
|
|
// Close the multipart writer to finalize the form
|
|
if err := multipartWriter.Close(); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
|
return
|
|
}
|
|
|
|
// Create a new request with the reconstructed form data
|
|
modifiedReq, err := http.NewRequestWithContext(
|
|
c.Request.Context(),
|
|
c.Request.Method,
|
|
c.Request.URL.String(),
|
|
&requestBuffer,
|
|
)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
|
return
|
|
}
|
|
|
|
// Copy the headers from the original request
|
|
modifiedReq.Header = c.Request.Header.Clone()
|
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
|
|
|
// set the content length of the body
|
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
|
|
|
// Use the modified request for proxying
|
|
if pm.metricsMonitor != nil {
|
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
return
|
|
}
|
|
} else {
|
|
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
|
requestedModel := c.Query("model")
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
|
return
|
|
}
|
|
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
var modelID string
|
|
|
|
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
|
modelID = realModelID
|
|
if pm.matrix != nil {
|
|
nextHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(realModelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
nextHandler = processGroup.ProxyRequest
|
|
}
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
modelID = requestedModel
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
|
acceptHeader := c.GetHeader("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
c.JSON(statusCode, gin.H{"error": message})
|
|
} else {
|
|
c.String(statusCode, message)
|
|
}
|
|
}
|
|
|
|
// apiKeyAuth returns a middleware that validates API keys if configured.
|
|
// Returns a pass-through handler if no API keys are configured.
|
|
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
|
if len(pm.config.RequiredAPIKeys) == 0 {
|
|
return func(c *gin.Context) { c.Next() }
|
|
}
|
|
|
|
return func(c *gin.Context) {
|
|
xApiKey := c.GetHeader("x-api-key")
|
|
|
|
var bearerKey string
|
|
var basicKey string
|
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
|
} else if strings.HasPrefix(auth, "Basic ") {
|
|
// Basic Auth: base64(username:password), password is the API key
|
|
encoded := strings.TrimPrefix(auth, "Basic ")
|
|
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
|
parts := strings.SplitN(string(decoded), ":", 2)
|
|
if len(parts) == 2 {
|
|
basicKey = parts[1] // password is the API key
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use first key found: Basic, then Bearer, then x-api-key
|
|
var providedKey string
|
|
if basicKey != "" {
|
|
providedKey = basicKey
|
|
} else if bearerKey != "" {
|
|
providedKey = bearerKey
|
|
} else {
|
|
providedKey = xApiKey
|
|
}
|
|
|
|
// Validate key
|
|
valid := false
|
|
for _, key := range pm.config.RequiredAPIKeys {
|
|
if providedKey == key {
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !valid {
|
|
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Strip auth headers to prevent leakage to upstream
|
|
c.Request.Header.Del("Authorization")
|
|
c.Request.Header.Del("x-api-key")
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
|
pm.StopProcesses(StopImmediately)
|
|
c.String(http.StatusOK, "OK")
|
|
}
|
|
|
|
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|
context.Header("Content-Type", "application/json")
|
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
|
|
|
if pm.matrix != nil {
|
|
for _, modelID := range pm.matrix.RunningModels() {
|
|
if process, ok := pm.matrix.GetProcess(modelID); ok {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.CurrentState(),
|
|
"cmd": process.config.Cmd,
|
|
"proxy": process.config.Proxy,
|
|
"ttl": process.config.UnloadAfter,
|
|
"name": process.config.Name,
|
|
"description": process.config.Description,
|
|
})
|
|
}
|
|
}
|
|
} else {
|
|
for _, processGroup := range pm.processGroups {
|
|
for _, process := range processGroup.processes {
|
|
if process.CurrentState() == StateReady {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.CurrentState(),
|
|
"cmd": process.config.Cmd,
|
|
"proxy": process.config.Proxy,
|
|
"ttl": process.config.UnloadAfter,
|
|
"name": process.config.Name,
|
|
"description": process.config.Description,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put the results under the `running` key.
|
|
response := gin.H{
|
|
"running": runningProcesses,
|
|
}
|
|
|
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
|
}
|
|
|
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
|
for _, group := range pm.processGroups {
|
|
if group.HasMember(modelName) {
|
|
return group
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
pm.buildDate = buildDate
|
|
pm.commit = commit
|
|
pm.version = version
|
|
}
|
|
|
|
func (pm *ProxyManager) SetPerfMonitor(m *perf.Monitor) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
pm.perfMonitor = m
|
|
}
|