mirror of
https://github.com/mostlygeek/llama-swap.git
synced 2026-06-09 14:56:34 +02:00
4635b3448c
Extend the peer proxy to support setParamsByID filters, matching the
existing capability for local models. The ${MODEL_ID} macro in
setParamsByID keys is expanded per-model so a single peer config
covering multiple models generates per-model aliases.
- Expand global macros in peer filters.setParamsByID during config load
- Validate setParamsByID values for unknown macros; allow ${MODEL_ID} in keys
- Build per-model expanded filters and alias map in NewPeerProxy
- Add RealPeerModelName to resolve aliases to base model IDs
- Update HasPeerModel to recognise setParamsByID aliases
- Rewrite model field in request body to base model ID when alias used
- Apply setParamsByID params in proxyInferenceHandler for peers
- Sync alias resolution to proxyOAIPostFormHandler and proxyGETModelHandler
fixes #697
1184 lines
37 KiB
Go
1184 lines
37 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/mostlygeek/llama-swap/event"
|
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
const (
	// PROFILE_SPLIT_CHAR is a separator character (":").
	// NOTE(review): not referenced in this part of the file; presumably the
	// separator used in profile-qualified model names — confirm at call sites.
	PROFILE_SPLIT_CHAR = ":"
)

// proxyCtxKey is the key type for values this file stores on a request's
// context (e.g. "streaming" and "model" in proxyInferenceHandler). Using a
// distinct named type keeps these keys from colliding with plain string keys
// set by other packages.
type proxyCtxKey string
|
|
|
|
// InflightCounter is a mutex-guarded count of requests currently in flight.
// The zero value is ready to use.
type InflightCounter struct {
	mu    sync.Mutex
	total int
}

// newInflightCounter returns a fresh counter starting at zero.
func newInflightCounter() *InflightCounter {
	return new(InflightCounter)
}

// Current reports the count at the moment of the call.
func (ic *InflightCounter) Current() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	return ic.total
}

// Increment adds one to the count and returns the updated value.
func (ic *InflightCounter) Increment() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	ic.total += 1
	return ic.total
}

// Decrement subtracts one from the count, never going below zero, and returns
// the updated value.
func (ic *InflightCounter) Decrement() int {
	ic.mu.Lock()
	defer ic.mu.Unlock()
	if ic.total <= 0 {
		// already at the floor; leave the count untouched
		return ic.total
	}
	ic.total -= 1
	return ic.total
}
|
|
|
|
// ProxyManager is the top-level http.Handler (see ServeHTTP). It owns the gin
// engine and the loggers, and routes requests either to local upstream
// processes (through process groups or the swap matrix) or to peers.
// The embedded mutex is taken by StopProcesses and Shutdown.
type ProxyManager struct {
	sync.Mutex

	config    config.Config
	ginEngine *gin.Engine

	// logging
	proxyLogger    *LogMonitor
	upstreamLogger *LogMonitor
	muxLogger      *LogMonitor

	metricsMonitor *metricsMonitor

	// process groups keyed by group ID (populated in New from config.Groups)
	processGroups map[string]*ProcessGroup

	// matrix-based swap (mutually exclusive with processGroups)
	matrix *Matrix

	// number of requests currently passing through trackInflight
	inFlightCounter *InflightCounter

	// shutdown signaling; shutdownCancel is called at the end of Shutdown
	shutdownCtx    context.Context
	shutdownCancel context.CancelFunc

	// version info; New sets placeholder values ("unknown", "abcd1234", "0").
	// NOTE(review): presumably overwritten elsewhere with real build info — confirm.
	buildDate string
	commit    string
	version   string

	// peer proxy see: #296, #433
	// nil when peering is not configured or NewPeerProxy failed.
	peerProxy *PeerProxy
}
|
|
|
|
// New builds a ProxyManager from proxyConfig. It does the following, in order:
//   - creates the mux, upstream and proxy loggers according to LogToStdout,
//     LogLevel and LogTimeFormat;
//   - creates the peer proxy (peering is disabled if that fails);
//   - creates either the swap matrix or one ProcessGroup per configured group;
//   - registers the HTTP routes;
//   - starts any OnStartup preload hooks in a background goroutine.
func New(proxyConfig config.Config) *ProxyManager {
	// set up loggers

	var muxLogger, upstreamLogger, proxyLogger *LogMonitor
	switch proxyConfig.LogToStdout {
	case config.LogToStdoutNone:
		muxLogger = NewLogMonitorWriter(io.Discard)
		upstreamLogger = NewLogMonitorWriter(io.Discard)
		proxyLogger = NewLogMonitorWriter(io.Discard)
	case config.LogToStdoutBoth:
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(muxLogger)
		proxyLogger = NewLogMonitorWriter(muxLogger)
	case config.LogToStdoutUpstream:
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(muxLogger)
		proxyLogger = NewLogMonitorWriter(io.Discard)
	default:
		// same as config.LogToStdoutProxy
		// helpful because some old tests create a config.Config directly and it
		// may not have LogToStdout set explicitly
		muxLogger = NewLogMonitorWriter(os.Stdout)
		upstreamLogger = NewLogMonitorWriter(io.Discard)
		proxyLogger = NewLogMonitorWriter(muxLogger)
	}

	if proxyConfig.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	// Unknown or empty log levels fall back to info.
	switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
		upstreamLogger.SetLogLevel(LevelDebug)
	case "info":
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
		upstreamLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
		upstreamLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	}

	// see: https://go.dev/src/time/format.go
	// Maps case-insensitive config names to Go time layouts; an unrecognised
	// name leaves the loggers' time format unchanged.
	timeFormats := map[string]string{
		"ansic":       time.ANSIC,
		"unixdate":    time.UnixDate,
		"rubydate":    time.RubyDate,
		"rfc822":      time.RFC822,
		"rfc822z":     time.RFC822Z,
		"rfc850":      time.RFC850,
		"rfc1123":     time.RFC1123,
		"rfc1123z":    time.RFC1123Z,
		"rfc3339":     time.RFC3339,
		"rfc3339nano": time.RFC3339Nano,
		"kitchen":     time.Kitchen,
		"stamp":       time.Stamp,
		"stampmilli":  time.StampMilli,
		"stampmicro":  time.StampMicro,
		"stampnano":   time.StampNano,
	}

	if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
		proxyLogger.SetLogTimeFormat(timeFormat)
		upstreamLogger.SetLogTimeFormat(timeFormat)
	}

	// cancelled by Shutdown
	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())

	var maxMetrics int
	if proxyConfig.MetricsMaxInMemory <= 0 {
		maxMetrics = 1000 // Default fallback
	} else {
		maxMetrics = proxyConfig.MetricsMaxInMemory
	}

	// A peer setup error is not fatal: peering is simply disabled.
	peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
	if err != nil {
		proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
		peerProxy = nil
	}

	pm := &ProxyManager{
		config:    proxyConfig,
		ginEngine: gin.New(),

		proxyLogger:    proxyLogger,
		muxLogger:      muxLogger,
		upstreamLogger: upstreamLogger,

		metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),

		processGroups: make(map[string]*ProcessGroup),

		inFlightCounter: newInflightCounter(),

		shutdownCtx:    shutdownCtx,
		shutdownCancel: shutdownCancel,

		buildDate: "unknown",
		commit:    "abcd1234",
		version:   "0",

		peerProxy: peerProxy,
	}

	// create either matrix or process groups (mutually exclusive)
	if proxyConfig.Matrix != nil {
		pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
	} else {
		for groupID := range proxyConfig.Groups {
			processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
			pm.processGroups[groupID] = processGroup
		}
	}

	pm.setupGinEngine()

	// run any startup hooks
	if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
		// do it in the background, don't block startup -- not sure if good idea yet
		// Each model is preloaded by proxying a dummy GET "/" request to it and
		// discarding the response. A ModelPreloadedEvent reports success or
		// failure per model.
		// NOTE(review): swapProcessGroup is called here without holding pm's
		// mutex — confirm that is safe with concurrent requests.
		go func() {
			discardWriter := &DiscardWriter{}
			for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
				modelID, ok := proxyConfig.RealModelName(preloadModelName)

				if !ok {
					proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
					continue
				}

				proxyLogger.Infof("Preloading model: %s", modelID)

				var preloadErr error
				req, _ := http.NewRequest("GET", "/", nil)

				if pm.matrix != nil {
					preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
				} else {
					processGroup, err := pm.swapProcessGroup(modelID)
					if err != nil {
						preloadErr = err
					} else {
						preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
					}
				}

				if preloadErr != nil {
					event.Emit(ModelPreloadedEvent{
						ModelName: modelID,
						Success:   false,
					})
					proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
					continue
				} else {
					event.Emit(ModelPreloadedEvent{
						ModelName: modelID,
						Success:   true,
					})
				}
			}
		}()
	}

	return pm
}
|
|
|
|
// setupGinEngine installs the global middleware and registers every HTTP route
// on pm.ginEngine. Called from New.
//
// Middleware runs in registration order: the request logger first, then the
// CORS preflight responder, so OPTIONS requests are logged too. Inference
// routes are wrapped with apiKeyAuth and trackInflight.
func (pm *ProxyManager) setupGinEngine() {

	// Request logging middleware: writes one line per request to the proxy
	// logger after the downstream handlers have run.
	pm.ginEngine.Use(func(c *gin.Context) {

		// don't log the Wake on Lan proxy health check
		if c.Request.URL.Path == "/wol-health" {
			c.Next()
			return
		}

		// Start timer
		start := time.Now()

		// capture these because /upstream/:model rewrites them in c.Next()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Process request
		c.Next()

		// Stop timer
		duration := time.Since(start)

		statusCode := c.Writer.Status()
		bodySize := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			clientIP,
			method,
			path,
			c.Request.Proto,
			statusCode,
			bodySize,
			c.Request.UserAgent(),
			duration,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	// Preflight requests are answered here with 204 and never reach a route handler.
	pm.ginEngine.Use(func(c *gin.Context) {
		if c.Request.Method == "OPTIONS" {
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

			// allow whatever the client requested by default
			if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
				sanitized := SanitizeAccessControlRequestHeaderValues(headers)
				c.Header("Access-Control-Allow-Headers", sanitized)
			} else {
				c.Header(
					"Access-Control-Allow-Headers",
					"Content-Type, Authorization, Accept, X-Requested-With",
				)
			}
			c.Header("Access-Control-Max-Age", "86400")
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	// Set up routes using the Gin engine
	// Protected routes use pm.apiKeyAuth() middleware
	pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	// Support legacy /v1/completions api, see issue #12
	pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
	pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	// Support anthropic count_tokens API (Also added in the above PR)
	pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)

	// Support embeddings and reranking
	pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)

	// llama-server's /reranking endpoint + aliases
	pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)

	// llama-server's /infill endpoint for code infilling
	pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)

	// llama-server's /completion endpoint
	pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)

	// Support audio/speech endpoint
	// Multipart endpoints (transcriptions, image edits) use the form handler;
	// GET endpoints use proxyGETModelHandler.
	pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
	pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
	pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)

	// sd.cpp /sdapi/v1 endpoints
	pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
	pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)

	pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
	pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)

	/**
	 * User Interface Endpoints
	 */
	pm.ginEngine.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui")
	})

	pm.ginEngine.GET("/upstream", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui/models")
	})
	pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
	pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
	pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
	// unauthenticated liveness check
	pm.ginEngine.GET("/health", func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	})

	// see cmd/wol-proxy/wol-proxy.go, not logged
	pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	})

	pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
		if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
			c.Data(http.StatusOK, "image/x-icon", data)
		} else {
			c.String(http.StatusInternalServerError, err.Error())
		}
	})

	// If the embedded UI cannot be loaded, the /ui routes and the SPA NoRoute
	// fallback are not registered at all.
	reactFS, err := GetReactFS()
	if err != nil {
		pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
	} else {
		// Serve files with compression support under /ui/*
		// This handler checks for pre-compressed .br and .gz files
		pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
			filepath := strings.TrimPrefix(c.Param("filepath"), "/")
			// Default to index.html for directory-like paths
			if filepath == "" {
				filepath = "index.html"
			}

			ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
		})

		// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
		// NOTE(review): HasPrefix(path, "/ui") also matches paths such as
		// "/uifoo"; confirm whether that is intended.
		pm.ginEngine.NoRoute(func(c *gin.Context) {
			if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
				c.AbortWithStatus(http.StatusNotFound)
				return
			}

			// Check if this looks like a file request (has extension)
			path := c.Request.URL.Path
			if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
				// This was likely a file request that wasn't found
				c.AbortWithStatus(http.StatusNotFound)
				return
			}

			// Serve index.html for SPA routing
			ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
		})
	}

	// see: proxymanager_api.go
	// add API handler functions
	addApiHandlers(pm)

	// Disable console color for testing
	gin.DisableConsoleColor()
}
|
|
|
|
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
|
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler interface
|
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
pm.ginEngine.ServeHTTP(w, r)
|
|
}
|
|
|
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
|
// This is the public method safe for concurrent calls.
|
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
|
// a complete shutdown, allowing for process replacement without full termination.
|
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
if pm.matrix != nil {
|
|
pm.matrix.StopProcesses(strategy)
|
|
return
|
|
}
|
|
|
|
// stop Processes in parallel
|
|
var wg sync.WaitGroup
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.StopProcesses(strategy)
|
|
}(processGroup)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// Shutdown stops all processes managed by this ProxyManager
|
|
func (pm *ProxyManager) Shutdown() {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
|
|
|
if pm.matrix != nil {
|
|
pm.matrix.Shutdown()
|
|
pm.shutdownCancel()
|
|
return
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
// Send shutdown signal to all process in groups
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.Shutdown()
|
|
}(processGroup)
|
|
}
|
|
wg.Wait()
|
|
pm.shutdownCancel()
|
|
}
|
|
|
|
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
|
processGroup := pm.findGroupByModelName(realModelName)
|
|
if processGroup == nil {
|
|
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
|
}
|
|
|
|
if processGroup.exclusive {
|
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
|
for groupId, otherGroup := range pm.processGroups {
|
|
if groupId != processGroup.id && !otherGroup.persistent {
|
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
|
}
|
|
}
|
|
}
|
|
|
|
return processGroup, nil
|
|
}
|
|
|
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|
data := make([]gin.H, 0, len(pm.config.Models))
|
|
createdTime := time.Now().Unix()
|
|
|
|
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
|
record := gin.H{
|
|
"id": modelId,
|
|
"object": "model",
|
|
"created": createdTime,
|
|
"owned_by": "llama-swap",
|
|
}
|
|
|
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
|
record["name"] = name
|
|
}
|
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
|
record["description"] = desc
|
|
}
|
|
|
|
// Add metadata if present
|
|
if len(modelConfig.Metadata) > 0 {
|
|
record["meta"] = gin.H{
|
|
"llamaswap": modelConfig.Metadata,
|
|
}
|
|
}
|
|
return record
|
|
}
|
|
|
|
for id, modelConfig := range pm.config.Models {
|
|
if modelConfig.Unlisted {
|
|
continue
|
|
}
|
|
|
|
data = append(data, newRecord(id, modelConfig))
|
|
|
|
// Include aliases
|
|
if pm.config.IncludeAliasesInList {
|
|
for _, alias := range modelConfig.Aliases {
|
|
if alias := strings.TrimSpace(alias); alias != "" {
|
|
data = append(data, newRecord(alias, modelConfig))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if pm.peerProxy != nil {
|
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
|
// add peer models
|
|
for _, modelID := range peer.Models {
|
|
// Skip unlisted models if not showing them
|
|
record := newRecord(modelID, config.ModelConfig{
|
|
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
|
Metadata: map[string]any{
|
|
"peerID": peerID,
|
|
},
|
|
})
|
|
|
|
data = append(data, record)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sort by the "id" key
|
|
sort.Slice(data, func(i, j int) bool {
|
|
si, _ := data[i]["id"].(string)
|
|
sj, _ := data[j]["id"].(string)
|
|
return si < sj
|
|
})
|
|
|
|
// Set CORS headers if origin exists
|
|
if origin := c.GetHeader("Origin"); origin != "" {
|
|
c.Header("Access-Control-Allow-Origin", origin)
|
|
}
|
|
|
|
// Use gin's JSON method which handles content-type and encoding
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
// findModelInPath searches for a valid model name in a path with slashes.
|
|
// It iteratively builds up path segments until it finds a matching model.
|
|
// Returns: (searchModelName, realModelName, remainingPath, found)
|
|
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
|
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
|
searchModelName := ""
|
|
|
|
for i, part := range parts {
|
|
if part == "" {
|
|
continue
|
|
}
|
|
|
|
if searchModelName == "" {
|
|
searchModelName = part
|
|
} else {
|
|
searchModelName = searchModelName + "/" + part
|
|
}
|
|
|
|
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
|
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
|
}
|
|
}
|
|
|
|
return "", "", "", false
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|
upstreamPath := c.Param("upstreamPath")
|
|
|
|
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
|
|
|
if !modelFound {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
|
return
|
|
}
|
|
|
|
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
|
// This ensures relative URLs in upstream responses resolve correctly and
|
|
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
|
// HTTP method (301 would downgrade to GET).
|
|
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
|
newPath := "/upstream/" + searchModelName + "/"
|
|
if c.Request.URL.RawQuery != "" {
|
|
newPath += "?" + c.Request.URL.RawQuery
|
|
}
|
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
|
c.Redirect(http.StatusMovedPermanently, newPath)
|
|
} else {
|
|
c.Redirect(http.StatusPermanentRedirect, newPath)
|
|
}
|
|
return
|
|
}
|
|
|
|
var handler func(string, http.ResponseWriter, *http.Request) error
|
|
if pm.matrix != nil {
|
|
handler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
handler = processGroup.ProxyRequest
|
|
}
|
|
|
|
// rewrite the path
|
|
originalPath := c.Request.URL.Path
|
|
c.Request.URL.Path = remainingPath
|
|
|
|
// attempt to record metrics if it is a POST request
|
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
|
return
|
|
}
|
|
} else {
|
|
if err := handler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
return
|
|
}
|
|
|
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
return
|
|
}
|
|
|
|
// Look for a matching local model first
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
|
|
modelID, found := pm.config.RealModelName(requestedModel)
|
|
if found {
|
|
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
|
if pm.matrix != nil {
|
|
localHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
localHandler = processGroup.ProxyRequest
|
|
}
|
|
|
|
// issue #69 allow custom model names to be sent to upstream
|
|
useModelName := pm.config.Models[modelID].UseModelName
|
|
if useModelName != "" {
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
return
|
|
}
|
|
}
|
|
|
|
// issue #174 strip parameters from the JSON body
|
|
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
|
if err != nil { // just log it and continue
|
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
|
} else {
|
|
for _, param := range stripParams {
|
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// issue #453 set/override parameters in the JSON body
|
|
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
|
for _, key := range setParamKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
|
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
|
for _, key := range setParamsByIDKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
nextHandler = localHandler
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
|
|
// Resolve alias to base model ID (e.g. "model_a:nothink" → "model_a")
|
|
realPeerModelID, _ := pm.peerProxy.RealPeerModelName(requestedModel)
|
|
modelID = realPeerModelID
|
|
|
|
// Rewrite model field to base model ID so the peer server receives the correct name
|
|
if realPeerModelID != requestedModel {
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", realPeerModelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error rewriting model name in JSON")
|
|
return
|
|
}
|
|
}
|
|
|
|
// issue #453 apply filters for peer requests
|
|
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
|
|
|
// Apply stripParams - remove specified parameters from request
|
|
stripParams := peerFilters.SanitizedStripParams()
|
|
for _, param := range stripParams {
|
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
|
return
|
|
}
|
|
}
|
|
|
|
// Apply setParams - set/override specified parameters in request
|
|
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
|
for _, key := range setParamKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
|
setParamsByIDParams, setParamsByIDKeys := peerFilters.SanitizedSetParamsByID(requestedModel)
|
|
for _, key := range setParamsByIDKeys {
|
|
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
return
|
|
}
|
|
}
|
|
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
// dechunk it as we already have all the body bytes see issue #11
|
|
c.Request.Header.Del("transfer-encoding")
|
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
|
c.Request.ContentLength = int64(len(bodyBytes))
|
|
|
|
// issue #366 extract values that downstream handlers may need
|
|
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
|
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
|
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
|
c.Request = c.Request.WithContext(ctx)
|
|
|
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
|
return
|
|
}
|
|
} else {
|
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|
// Parse multipart form
|
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// Get model parameter from the form
|
|
requestedModel := c.Request.FormValue("model")
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
|
return
|
|
}
|
|
|
|
// Look for a matching local model first, then check peers
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
var useModelName string
|
|
|
|
modelID, found := pm.config.RealModelName(requestedModel)
|
|
if found {
|
|
if pm.matrix != nil {
|
|
nextHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(modelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
nextHandler = processGroup.ProxyRequest
|
|
}
|
|
|
|
useModelName = pm.config.Models[modelID].UseModelName
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
realPeerModelID, _ := pm.peerProxy.RealPeerModelName(requestedModel)
|
|
modelID = realPeerModelID
|
|
if realPeerModelID != requestedModel {
|
|
useModelName = realPeerModelID // rewrite model field in reconstructed form
|
|
}
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
// Create a new buffer for the reconstructed request
|
|
var requestBuffer bytes.Buffer
|
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
|
|
// Copy all form values
|
|
for key, values := range c.Request.MultipartForm.Value {
|
|
for _, value := range values {
|
|
fieldValue := value
|
|
// If this is the model field and we have a profile, use just the model name
|
|
if key == "model" {
|
|
// # issue #69 allow custom model names to be sent to upstream
|
|
if useModelName != "" {
|
|
fieldValue = useModelName
|
|
} else {
|
|
fieldValue = requestedModel
|
|
}
|
|
}
|
|
field, err := multipartWriter.CreateFormField(key)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
|
return
|
|
}
|
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Copy all files from the original request
|
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
|
for _, fileHeader := range fileHeaders {
|
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
|
return
|
|
}
|
|
|
|
file, err := fileHeader.Open()
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
|
return
|
|
}
|
|
|
|
if _, err = io.Copy(formFile, file); err != nil {
|
|
file.Close()
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
|
return
|
|
}
|
|
file.Close()
|
|
}
|
|
}
|
|
|
|
// Close the multipart writer to finalize the form
|
|
if err := multipartWriter.Close(); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
|
return
|
|
}
|
|
|
|
// Create a new request with the reconstructed form data
|
|
modifiedReq, err := http.NewRequestWithContext(
|
|
c.Request.Context(),
|
|
c.Request.Method,
|
|
c.Request.URL.String(),
|
|
&requestBuffer,
|
|
)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
|
return
|
|
}
|
|
|
|
// Copy the headers from the original request
|
|
modifiedReq.Header = c.Request.Header.Clone()
|
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
|
|
|
// set the content length of the body
|
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
|
|
|
// Use the modified request for proxying
|
|
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
|
requestedModel := c.Query("model")
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
|
return
|
|
}
|
|
|
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
var modelID string
|
|
|
|
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
|
modelID = realModelID
|
|
if pm.matrix != nil {
|
|
nextHandler = pm.matrix.ProxyRequest
|
|
} else {
|
|
processGroup, err := pm.swapProcessGroup(realModelID)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
nextHandler = processGroup.ProxyRequest
|
|
}
|
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
realPeerModelID, _ := pm.peerProxy.RealPeerModelName(requestedModel)
|
|
modelID = realPeerModelID
|
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
nextHandler = pm.peerProxy.ProxyRequest
|
|
}
|
|
|
|
if nextHandler == nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
|
return
|
|
}
|
|
|
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
|
acceptHeader := c.GetHeader("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
c.JSON(statusCode, gin.H{"error": message})
|
|
} else {
|
|
c.String(statusCode, message)
|
|
}
|
|
}
|
|
|
|
// apiKeyAuth returns a middleware that validates API keys if configured.
|
|
// Returns a pass-through handler if no API keys are configured.
|
|
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
|
if len(pm.config.RequiredAPIKeys) == 0 {
|
|
return func(c *gin.Context) { c.Next() }
|
|
}
|
|
|
|
return func(c *gin.Context) {
|
|
xApiKey := c.GetHeader("x-api-key")
|
|
|
|
var bearerKey string
|
|
var basicKey string
|
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
|
} else if strings.HasPrefix(auth, "Basic ") {
|
|
// Basic Auth: base64(username:password), password is the API key
|
|
encoded := strings.TrimPrefix(auth, "Basic ")
|
|
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
|
parts := strings.SplitN(string(decoded), ":", 2)
|
|
if len(parts) == 2 {
|
|
basicKey = parts[1] // password is the API key
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use first key found: Basic, then Bearer, then x-api-key
|
|
var providedKey string
|
|
if basicKey != "" {
|
|
providedKey = basicKey
|
|
} else if bearerKey != "" {
|
|
providedKey = bearerKey
|
|
} else {
|
|
providedKey = xApiKey
|
|
}
|
|
|
|
// Validate key
|
|
valid := false
|
|
for _, key := range pm.config.RequiredAPIKeys {
|
|
if providedKey == key {
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !valid {
|
|
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Strip auth headers to prevent leakage to upstream
|
|
c.Request.Header.Del("Authorization")
|
|
c.Request.Header.Del("x-api-key")
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
|
pm.StopProcesses(StopImmediately)
|
|
c.String(http.StatusOK, "OK")
|
|
}
|
|
|
|
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|
context.Header("Content-Type", "application/json")
|
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
|
|
|
if pm.matrix != nil {
|
|
for _, modelID := range pm.matrix.RunningModels() {
|
|
if process, ok := pm.matrix.GetProcess(modelID); ok {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.state,
|
|
"cmd": process.config.Cmd,
|
|
"proxy": process.config.Proxy,
|
|
"ttl": process.config.UnloadAfter,
|
|
"name": process.config.Name,
|
|
"description": process.config.Description,
|
|
})
|
|
}
|
|
}
|
|
} else {
|
|
for _, processGroup := range pm.processGroups {
|
|
for _, process := range processGroup.processes {
|
|
if process.CurrentState() == StateReady {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.state,
|
|
"cmd": process.config.Cmd,
|
|
"proxy": process.config.Proxy,
|
|
"ttl": process.config.UnloadAfter,
|
|
"name": process.config.Name,
|
|
"description": process.config.Description,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put the results under the `running` key.
|
|
response := gin.H{
|
|
"running": runningProcesses,
|
|
}
|
|
|
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
|
}
|
|
|
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
|
for _, group := range pm.processGroups {
|
|
if group.HasMember(modelName) {
|
|
return group
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
pm.buildDate = buildDate
|
|
pm.commit = commit
|
|
pm.version = version
|
|
}
|