Compare commits

...

4 Commits

Author SHA1 Message Date
Benson Wong 04955cee1d proxy: add debug logging to matrix solver decisions
Add set name, DSL expression, and eviction cost to SolveResult and
log solver decisions in ProxyRequest for troubleshooting.

- Add DSL field to ExpandedSet, populated during ValidateMatrix
- Add SetName, DSL, TotalCost to SolveResult
- Log eviction, cold start, and already-running cases via proxyLogger
2026-04-12 04:24:43 +00:00
Benson Wong c492fa8ee3 config.example.yaml: uncomment matrix section 2026-04-12 00:23:52 +00:00
Benson Wong f0fd2a9765 proxy: rename matrix aliases to map, enforce map-only IDs in sets and evict_costs
Rename the matrix `aliases` field to `map` and require that `sets` DSL
expressions and `evict_costs` keys only use IDs defined in the map.
Real model names can no longer be used directly — all identifiers must
resolve through the map, removing ambiguity from the DSL.

- Rename Aliases field to Map (yaml tag "map")
- Remove resolveIdentifier closure, use direct matrix.Map lookups
- Remove collision check (no longer relevant with map-only resolution)
- Update error messages to "unknown map ID"
- Update config example, JSON schema, and tests
2026-04-12 00:07:25 +00:00
Benson Wong 4a75fc62fc proxy: add swap matrix with solver-based model swapping
Add a new "matrix" configuration as an alternative to groups. The matrix
uses a solver to find the cheapest way to make a requested model available
by minimizing eviction costs across defined concurrent model sets.

- add DSL parser for set expressions with & (AND), | (OR), (), +ref
- add MatrixConfig structs, validation, and topological sort for +ref
- add MatrixSolver with cost-minimizing swap decisions
- add Matrix runtime integrating solver with Process lifecycle
- integrate matrix into ProxyManager with if-branches at all endpoints
- update config.example.yaml and config-schema.json with matrix schema
- config enforces groups XOR matrix (cannot use both)
2026-04-12 00:06:10 +00:00
11 changed files with 1973 additions and 75 deletions
+35
View File
@@ -319,6 +319,41 @@
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
"matrix": {
"type": "object",
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
"required": [
"sets"
],
"properties": {
"map": {
"type": "object",
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters, and must not collide with model names. All sets and evict_costs must use these IDs.",
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"pattern": "^[a-zA-Z0-9]{1,8}$"
}
},
"evict_costs": {
"type": "object",
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
"additionalProperties": {
"type": "integer",
"minimum": 1
}
},
"sets": {
"type": "object",
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
"additionalProperties": {
"type": "string"
}
}
},
"additionalProperties": false
},
"hooks": {
"type": "object",
"properties": {
+74
View File
@@ -393,6 +393,80 @@ groups:
- "forever-modelB"
- "forever-modelc"
# =============================================================================
# matrix: solver-based alternative to groups
# =============================================================================
#
# A config must use either groups or matrix, not both.
#
# The matrix declares valid combinations of models that can run concurrently.
# When a model is requested, the solver finds the cheapest way to make it
# available by evicting as few (and least costly) running models as possible.
#
# Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# map: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a model ID
# - used to keep set DSL logic short and easier to read
map:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# evict_costs: relative cost of losing a running model (default: 1)
evict_costs:
v: 50 # vllm backend, slow cold start
L: 30 # 70B weights, slow to load
# sets: named sets of concurrent model combinations
# Values are DSL strings with operators:
# & AND (models run together)
# | OR (alternatives)
# () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# LLM + TTS + reranker
# expands to: [g,v,e], [q,v,e]
with_rerank: "(g | q) & v & e"
# LLM + image generation, no TTS
# expands to: [g,sd], [q,sd]
creative: "(g | q) & sd"
# 70B model uses all GPUs, can only run alone
# expands to: [L]
full: "L"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
# - the only supported hook is on_startup
+32 -13
View File
@@ -129,6 +129,12 @@ type Config struct {
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// swap matrix: solver-based alternative to groups
Matrix *MatrixConfig `yaml:"matrix"`
// populated during validation when matrix is configured
ExpandedSets []ExpandedSet `yaml:"-"`
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"`
@@ -438,22 +444,35 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.Models[modelId] = modelConfig
}
config = AddDefaultGroupToConfig(config)
// groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
}
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if config.Matrix != nil {
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
if err != nil {
return Config{}, fmt.Errorf("matrix: %w", err)
}
config.ExpandedSets = expandedSets
} else {
config = AddDefaultGroupToConfig(config)
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
memberUsage[member] = groupID
}
}
+222
View File
@@ -0,0 +1,222 @@
package config
import (
"fmt"
"regexp"
"sort"
"gopkg.in/yaml.v3"
)
var mapKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
// MatrixConfig represents the swap matrix configuration block.
type MatrixConfig struct {
Map map[string]string `yaml:"map"`
EvictCosts map[string]int `yaml:"evict_costs"`
Sets OrderedSets `yaml:"sets"`
}
// SetEntry is a single named set with its DSL expression.
type SetEntry struct {
Name string
DSL string
}
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
type OrderedSets []SetEntry
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("sets must be a mapping")
}
entries := make([]SetEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode set name: %w", err)
}
var dsl string
if err := valueNode.Decode(&dsl); err != nil {
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
}
entries = append(entries, SetEntry{Name: name, DSL: dsl})
}
*os = entries
return nil
}
// ExpandedSet is one valid combination of concurrent models (real model names).
type ExpandedSet struct {
SetName string
DSL string
Models []string // real model names, sorted
}
// ValidateMatrix validates the matrix config and returns all expanded sets.
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
if len(matrix.Sets) == 0 {
return nil, fmt.Errorf("matrix must define at least one set")
}
// Validate map entries
if matrix.Map != nil {
for id, modelName := range matrix.Map {
if !mapKeyPattern.MatchString(id) {
return nil, fmt.Errorf("map key %q must be alphanumeric and 1-8 characters", id)
}
if _, exists := models[modelName]; !exists {
return nil, fmt.Errorf("map key %q references unknown model %q", id, modelName)
}
}
}
// Validate evict_costs
if matrix.EvictCosts != nil {
for key, cost := range matrix.EvictCosts {
if cost <= 0 {
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
}
if _, ok := matrix.Map[key]; !ok {
return nil, fmt.Errorf("evict_costs: unknown map ID %q", key)
}
}
}
// Build dependency graph for +ref topological sort
setNames := make(map[string]bool)
for _, entry := range matrix.Sets {
setNames[entry.Name] = true
}
deps := make(map[string][]string) // setName -> set names it depends on
for _, entry := range matrix.Sets {
refs, err := extractRefs(entry.DSL)
if err != nil {
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
}
for _, ref := range refs {
if !setNames[ref] {
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
}
}
deps[entry.Name] = refs
}
// Topological sort with cycle detection
order, err := topologicalSort(matrix.Sets, deps)
if err != nil {
return nil, err
}
// Expand sets in topological order
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
var allExpanded []ExpandedSet
totalCombinations := 0
// Build ordered map for efficient lookup
setDSL := make(map[string]string)
for _, entry := range matrix.Sets {
setDSL[entry.Name] = entry.DSL
}
for _, name := range order {
dsl := setDSL[name]
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
if err != nil {
return nil, fmt.Errorf("set %q: %w", name, err)
}
resolvedRefs[name] = combos
// Resolve map IDs to real model names
for _, combo := range combos {
resolved := make([]string, len(combo))
for i, ident := range combo {
realName, ok := matrix.Map[ident]
if !ok {
return nil, fmt.Errorf("set %q: unknown map ID %q", name, ident)
}
resolved[i] = realName
}
sort.Strings(resolved)
allExpanded = append(allExpanded, ExpandedSet{
SetName: name,
DSL: dsl,
Models: resolved,
})
}
totalCombinations += len(combos)
if totalCombinations > maxDSLExpansions {
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
}
}
return allExpanded, nil
}
// topologicalSort returns set names in dependency order.
// Returns an error if a cycle is detected.
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
// States: 0 = unvisited, 1 = visiting, 2 = visited
state := make(map[string]int)
var order []string
var visit func(name string) error
visit = func(name string) error {
switch state[name] {
case 1:
return fmt.Errorf("circular reference detected involving set %q", name)
case 2:
return nil
}
state[name] = 1
for _, dep := range deps[name] {
if err := visit(dep); err != nil {
return err
}
}
state[name] = 2
order = append(order, name)
return nil
}
// Visit in definition order for deterministic output
for _, entry := range sets {
if state[entry.Name] == 0 {
if err := visit(entry.Name); err != nil {
return nil, err
}
}
}
return order, nil
}
// ResolvedEvictCosts returns a map of real model name -> evict cost,
// resolving map IDs. Models not listed default to 1.
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
costs := make(map[string]int)
if m.EvictCosts == nil {
return costs
}
for key, cost := range m.EvictCosts {
// Resolve map ID if present
if realName, ok := m.Map[key]; ok {
costs[realName] = cost
} else {
costs[key] = cost
}
}
return costs
}
+372
View File
@@ -0,0 +1,372 @@
package config
import (
"fmt"
"sort"
"strings"
"unicode"
)
const maxDSLExpansions = 1000
// Token types for the DSL lexer
type tokenType int
const (
tokIdent tokenType = iota // model alias or name
tokAnd // &
tokOr // |
tokLParen // (
tokRParen // )
tokRef // +setName
tokEOF
)
type token struct {
typ tokenType
val string
}
// tokenize splits a DSL string into tokens.
func tokenize(input string) ([]token, error) {
var tokens []token
i := 0
runes := []rune(input)
for i < len(runes) {
ch := runes[i]
// skip whitespace
if unicode.IsSpace(ch) {
i++
continue
}
switch ch {
case '&':
tokens = append(tokens, token{tokAnd, "&"})
i++
case '|':
tokens = append(tokens, token{tokOr, "|"})
i++
case '(':
tokens = append(tokens, token{tokLParen, "("})
i++
case ')':
tokens = append(tokens, token{tokRParen, ")"})
i++
case '+':
// +ref: read the identifier that follows
i++
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
if i == start {
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
}
tokens = append(tokens, token{tokRef, string(runes[start:i])})
default:
if isIdentChar(ch) {
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
} else {
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
}
}
}
tokens = append(tokens, token{tokEOF, ""})
return tokens, nil
}
func isIdentChar(ch rune) bool {
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
}
// AST node types
type dslNode interface {
dslNode()
}
type andNode struct {
children []dslNode
}
type orNode struct {
children []dslNode
}
type leafNode struct {
name string
}
type refNode struct {
setName string
}
func (andNode) dslNode() {}
func (orNode) dslNode() {}
func (leafNode) dslNode() {}
func (refNode) dslNode() {}
// parser holds state for recursive-descent parsing.
type parser struct {
tokens []token
pos int
}
func (p *parser) peek() token {
if p.pos < len(p.tokens) {
return p.tokens[p.pos]
}
return token{tokEOF, ""}
}
func (p *parser) next() token {
t := p.peek()
if t.typ != tokEOF {
p.pos++
}
return t
}
func (p *parser) expect(typ tokenType) (token, error) {
t := p.next()
if t.typ != typ {
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
}
return t, nil
}
// Grammar:
//
// expr = andExpr
// andExpr = orExpr ('&' orExpr)*
// orExpr = atom ('|' atom)*
// atom = ident | '+' ident | '(' expr ')'
//
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
func parse(tokens []token) (dslNode, error) {
p := &parser{tokens: tokens}
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if p.peek().typ != tokEOF {
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
}
return node, nil
}
func (p *parser) parseExpr() (dslNode, error) {
return p.parseOrExpr()
}
func (p *parser) parseOrExpr() (dslNode, error) {
left, err := p.parseAndExpr()
if err != nil {
return nil, err
}
if p.peek().typ == tokOr {
children := []dslNode{left}
for p.peek().typ == tokOr {
p.next() // consume |
right, err := p.parseAndExpr()
if err != nil {
return nil, err
}
children = append(children, right)
}
return orNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAndExpr() (dslNode, error) {
left, err := p.parseAtom()
if err != nil {
return nil, err
}
if p.peek().typ == tokAnd {
children := []dslNode{left}
for p.peek().typ == tokAnd {
p.next() // consume &
right, err := p.parseAtom()
if err != nil {
return nil, err
}
children = append(children, right)
}
return andNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAtom() (dslNode, error) {
t := p.peek()
switch t.typ {
case tokIdent:
p.next()
return leafNode{name: t.val}, nil
case tokRef:
p.next()
return refNode{setName: t.val}, nil
case tokLParen:
p.next() // consume (
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if _, err := p.expect(tokRParen); err != nil {
return nil, fmt.Errorf("missing closing parenthesis")
}
return node, nil
default:
return nil, fmt.Errorf("unexpected token %q", t.val)
}
}
// expand walks the AST and produces all combinations.
// resolvedRefs contains previously expanded sets for +ref resolution.
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
switch n := node.(type) {
case leafNode:
return [][]string{{n.name}}, nil
case refNode:
expanded, ok := resolvedRefs[n.setName]
if !ok {
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
}
// Return a copy
result := make([][]string, len(expanded))
for i, combo := range expanded {
result[i] = make([]string, len(combo))
copy(result[i], combo)
}
return result, nil
case orNode:
// Union of all children's expansions
var result [][]string
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result = append(result, childResult...)
if len(result) > maxDSLExpansions {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
}
}
return result, nil
case andNode:
// Cartesian product across children
result := [][]string{{}} // start with one empty combo
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result = cartesianProduct(result, childResult)
if len(result) > maxDSLExpansions {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
}
}
return result, nil
default:
return nil, fmt.Errorf("unknown node type %T", node)
}
}
// cartesianProduct computes the cartesian product of two sets of combinations.
func cartesianProduct(left, right [][]string) [][]string {
var result [][]string
for _, l := range left {
for _, r := range right {
combo := make([]string, 0, len(l)+len(r))
combo = append(combo, l...)
combo = append(combo, r...)
result = append(result, combo)
}
}
return result
}
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
// resolvedRefs contains previously expanded sets for +ref inlining.
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
dsl = strings.TrimSpace(dsl)
if dsl == "" {
return nil, fmt.Errorf("empty DSL expression")
}
tokens, err := tokenize(dsl)
if err != nil {
return nil, fmt.Errorf("tokenize: %w", err)
}
tree, err := parse(tokens)
if err != nil {
return nil, fmt.Errorf("parse: %w", err)
}
result, err := expand(tree, resolvedRefs)
if err != nil {
return nil, err
}
// Deduplicate models within each combination and sort for consistency
for i, combo := range result {
result[i] = dedupAndSort(combo)
}
return result, nil
}
// dedupAndSort removes duplicate entries and sorts alphabetically.
func dedupAndSort(items []string) []string {
seen := make(map[string]bool, len(items))
var unique []string
for _, item := range items {
if !seen[item] {
seen[item] = true
unique = append(unique, item)
}
}
sort.Strings(unique)
return unique
}
// extractRefs scans a DSL string for +ref tokens without full parsing.
// Used for building the dependency graph for topological sorting.
func extractRefs(dsl string) ([]string, error) {
tokens, err := tokenize(dsl)
if err != nil {
return nil, err
}
var refs []string
seen := make(map[string]bool)
for _, t := range tokens {
if t.typ == tokRef && !seen[t.val] {
seen[t.val] = true
refs = append(refs, t.val)
}
}
return refs, nil
}
+300
View File
@@ -0,0 +1,300 @@
package config
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDSL_Tokenize(t *testing.T) {
tests := []struct {
name string
input string
expect []token
errMsg string
}{
{
name: "single identifier",
input: "abc",
expect: []token{
{tokIdent, "abc"},
{tokEOF, ""},
},
},
{
name: "identifier with hyphens and dots",
input: "model-name.v2",
expect: []token{
{tokIdent, "model-name.v2"},
{tokEOF, ""},
},
},
{
name: "and expression",
input: "a & b",
expect: []token{
{tokIdent, "a"},
{tokAnd, "&"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "or expression",
input: "a | b",
expect: []token{
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "parentheses",
input: "(a | b) & c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "ref token",
input: "+llms & v",
expect: []token{
{tokRef, "llms"},
{tokAnd, "&"},
{tokIdent, "v"},
{tokEOF, ""},
},
},
{
name: "no whitespace",
input: "(a|b)&c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "empty ref",
input: "+",
errMsg: "expected set name after '+'",
},
{
name: "invalid character",
input: "a @ b",
errMsg: "unexpected character",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens, err := tokenize(tt.input)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, tokens)
}
})
}
}
func TestDSL_ParseAndExpand(t *testing.T) {
tests := []struct {
name string
dsl string
refs map[string][][]string
expect [][]string
errMsg string
}{
{
name: "single model",
dsl: "L",
expect: [][]string{{"L"}},
},
{
name: "two models with AND",
dsl: "a & b",
expect: [][]string{{"a", "b"}},
},
{
name: "two models with OR",
dsl: "a | b",
expect: [][]string{{"a"}, {"b"}},
},
{
name: "three models with OR",
dsl: "a | b | c",
expect: [][]string{{"a"}, {"b"}, {"c"}},
},
{
name: "cartesian product (a|b) & (c|d)",
dsl: "(a | b) & (c | d)",
expect: [][]string{
{"a", "c"},
{"a", "d"},
{"b", "c"},
{"b", "d"},
},
},
{
name: "three-way AND",
dsl: "a & b & c",
expect: [][]string{
{"a", "b", "c"},
},
},
{
name: "(g | q | m) & v",
dsl: "(g | q | m) & v",
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "(g | q) & v & e",
dsl: "(g | q) & v & e",
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
},
},
{
name: "precedence: a | b & c means a | (b & c)",
dsl: "a | b & c",
expect: [][]string{
{"a"},
{"b", "c"},
},
},
{
name: "+ref inlining",
dsl: "+llms & v",
refs: map[string][][]string{
"llms": {{"g"}, {"q"}, {"m"}},
},
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "+ref chained",
dsl: "+with_tts & e",
refs: map[string][][]string{
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
},
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
{"e", "m", "v"},
},
},
{
name: "dedup within combination",
dsl: "a & a",
expect: [][]string{
{"a"},
},
},
{
name: "empty expression",
dsl: "",
errMsg: "empty DSL expression",
},
{
name: "unmatched open paren",
dsl: "(a | b",
errMsg: "missing closing parenthesis",
},
{
name: "unmatched close paren",
dsl: "a | b)",
errMsg: "unexpected token",
},
{
name: "unknown ref",
dsl: "+unknown",
errMsg: "unknown set reference +unknown",
},
{
name: "empty parens",
dsl: "()",
errMsg: "unexpected token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
refs := tt.refs
if refs == nil {
refs = map[string][][]string{}
}
result, err := ParseAndExpandDSL(tt.dsl, refs)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, result)
}
})
}
}
func TestDSL_ExpansionCap(t *testing.T) {
// Build an expression that would exceed 1000 combinations:
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
var aItems, bItems []string
for i := 0; i < 32; i++ {
aItems = append(aItems, fmt.Sprintf("a%d", i))
bItems = append(bItems, fmt.Sprintf("b%d", i))
}
dsl := fmt.Sprintf("(%s) & (%s)",
join(aItems, " | "),
join(bItems, " | "),
)
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeded")
}
func TestDSL_ExtractRefs(t *testing.T) {
refs, err := extractRefs("+llms & v & +other")
require.NoError(t, err)
assert.Equal(t, []string{"llms", "other"}, refs)
refs, err = extractRefs("a & b")
require.NoError(t, err)
assert.Empty(t, refs)
}
func join(items []string, sep string) string {
result := ""
for i, item := range items {
if i > 0 {
result += sep
}
result += item
}
return result
}
+305
View File
@@ -0,0 +1,305 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func makeModels(names ...string) map[string]ModelConfig {
m := make(map[string]ModelConfig)
for _, name := range names {
m[name] = ModelConfig{Cmd: "echo " + name}
}
return m
}
func TestValidateMatrix_Basic(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
matrix := MatrixConfig{
Map: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"v": 50,
},
Sets: OrderedSets{
{Name: "standard", DSL: "(g | q | m) & v"},
{Name: "full", DSL: "L"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// full expands to [llama70B]
assert.Len(t, expanded, 4)
assert.Equal(t, "standard", expanded[0].SetName)
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
assert.Equal(t, "standard", expanded[1].SetName)
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
assert.Equal(t, "standard", expanded[2].SetName)
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
assert.Equal(t, "full", expanded[3].SetName)
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
}
func TestValidateMatrix_WithRef(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
matrix := MatrixConfig{
Map: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"e": "reranker",
},
Sets: OrderedSets{
{Name: "llms", DSL: "g | q | m"},
{Name: "with_tts", DSL: "+llms & v"},
{Name: "mega", DSL: "+with_tts & e"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// llms: [gemma], [qwen], [mistral]
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
assert.Len(t, expanded, 9)
// Check mega entries
megaEntries := filterBySetName(expanded, "mega")
assert.Len(t, megaEntries, 3)
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
}
func TestValidateMatrix_MapIDRequired(t *testing.T) {
// DSL cannot use real model names directly — must use map IDs
models := makeModels("gemma", "voxtral")
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "combo", DSL: "g & voxtral"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown map ID")
}
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
models := makeModels("gemma")
tests := []struct {
name string
alias string
errMsg string
}{
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matrix := MatrixConfig{
Map: map[string]string{tt.alias: "gemma"},
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
})
}
}
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Map: map[string]string{"x": "nonexistent"},
Sets: OrderedSets{{Name: "s", DSL: "x"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown model")
}
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
models := makeModels("gemma")
t.Run("zero cost", func(t *testing.T) {
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": 0},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("negative cost", func(t *testing.T) {
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": -1},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("unknown map ID in evict_costs", func(t *testing.T) {
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"unknown": 5},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown map ID")
})
}
func TestValidateMatrix_CycleDetection(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+b"},
{Name: "b", DSL: "+a"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "circular reference")
}
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "references undefined set")
}
func TestValidateMatrix_NoSets(t *testing.T) {
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one set")
}
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Map: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "s", DSL: "g & nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown map ID")
}
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
mc := &MatrixConfig{
Map: map[string]string{
"g": "gemma",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"g": 5,
},
}
costs := mc.ResolvedEvictCosts()
assert.Equal(t, 30, costs["llama70B"])
assert.Equal(t, 5, costs["gemma"])
}
func TestValidateMatrix_ConfigXOR(t *testing.T) {
// groups and matrix both defined
yaml := `
models:
model1:
cmd: echo model1
proxy: http://localhost:8080
groups:
group1:
members:
- model1
matrix:
sets:
s: "model1"
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot use both")
}
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
yaml := `
models:
gemma:
cmd: echo gemma
proxy: http://localhost:8080
qwen:
cmd: echo qwen
proxy: http://localhost:8081
matrix:
map:
g: gemma
q: qwen
sets:
combo: "g | q"
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.NotNil(t, cfg.Matrix)
assert.Len(t, cfg.ExpandedSets, 2)
// Groups should be empty when matrix is used
assert.Empty(t, cfg.Groups)
}
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
var result []ExpandedSet
for _, s := range sets {
if s.SetName == name {
result = append(result, s)
}
}
return result
}
+274
View File
@@ -0,0 +1,274 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sort"
"sync"
"github.com/mostlygeek/llama-swap/proxy/config"
)
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type MatrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &MatrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// SolveResult describes what the solver decided.
type SolveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
// If already running, nothing to do
if slices.Contains(runningModels, requestedModel) {
return SolveResult{}, nil
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return SolveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}, nil
}
// Find the cheapest candidate set
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
// Determine which running models to evict
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return SolveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}, nil
}
func (s *MatrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
// Matrix manages processes using solver-based swap logic.
type Matrix struct {
sync.Mutex
solver *MatrixSolver
processes map[string]*Process // all processes keyed by real model name
config config.Config
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
}
// NewMatrix creates a Matrix from config. It creates a Process for every
// model defined in the config (any model can run alone even if not in a set).
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
processes := make(map[string]*Process)
for modelID, modelConfig := range cfg.Models {
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
processes[modelID] = process
}
evictCosts := cfg.Matrix.ResolvedEvictCosts()
return &Matrix{
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
processes: processes,
config: cfg,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
}
}
// ProxyRequest handles the swap logic and proxies the request to the model.
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("model %s not found in matrix", modelID)
}
m.Lock()
running := m.runningModels()
result, err := m.solver.Solve(modelID, running)
if err != nil {
m.Unlock()
return fmt.Errorf("matrix solver error: %w", err)
}
// Log solver decision
if len(result.Evict) > 0 {
m.proxyLogger.Debugf("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
} else if len(running) == 0 {
m.proxyLogger.Debugf("Matrix: model=%s starting (no models running)", modelID)
} else {
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
}
// Evict models that need to be stopped
if len(result.Evict) > 0 {
var wg sync.WaitGroup
for _, evictModel := range result.Evict {
if p, exists := m.processes[evictModel]; exists {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Stop()
}(p)
}
}
wg.Wait()
}
m.Unlock()
// Proxy the request (Process handles on-demand start)
process.ProxyRequest(w, r)
return nil
}
// StopProcesses stops all running processes.
func (m *Matrix) StopProcesses(strategy StopStrategy) {
m.Lock()
defer m.Unlock()
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
p.StopImmediately()
default:
p.Stop()
}
}(process)
}
wg.Wait()
}
// StopProcess stops a single process by model ID.
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("process not found for %s", modelID)
}
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}
// Shutdown shuts down all processes.
func (m *Matrix) Shutdown() {
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Shutdown()
}(process)
}
wg.Wait()
}
// RunningModels returns model names currently in StateReady.
func (m *Matrix) RunningModels() []string {
m.Lock()
defer m.Unlock()
return m.runningModels()
}
// runningModels returns running model names (caller must hold lock).
func (m *Matrix) runningModels() []string {
var running []string
for id, process := range m.processes {
if process.CurrentState() == StateReady {
running = append(running, id)
}
}
sort.Strings(running)
return running
}
// GetProcess returns the Process for a model.
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
p, ok := m.processes[modelID]
return p, ok
}
// HasModel returns true if the model is managed by this matrix.
func (m *Matrix) HasModel(modelID string) bool {
_, ok := m.processes[modelID]
return ok
}
+226
View File
@@ -0,0 +1,226 @@
package proxy
import (
"testing"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper to build expanded sets for solver tests
func makeExpandedSets(sets ...struct {
name string
models []string
}) []config.ExpandedSet {
var result []config.ExpandedSet
for _, s := range sets {
result = append(result, config.ExpandedSet{
SetName: s.name,
Models: s.models,
})
}
return result
}
func es(name string, models ...string) struct {
name string
models []string
} {
return struct {
name string
models []string
}{name, models}
}
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"a"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Nil(t, result.TargetSet)
}
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
// Model "c" not in any set
result, err := solver.Solve("c", []string{"a", "b"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("c", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
// Set: [a, b]. Request a when b and c are running.
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"b", "c"})
require.NoError(t, err)
// c is not in the set, so it gets evicted. b is in the set, so it stays.
assert.Equal(t, []string{"c"}, result.Evict)
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
}
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
// Two sets containing model "a":
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "v"),
es("s2", "a", "L"),
),
map[string]int{"v": 50, "L": 30},
)
// v is running. Switching to a:
// s1 cost: v is in s1, so 0
// s2 cost: v is NOT in s2, so 50
// => pick s1
result, err := solver.Solve("a", []string{"v"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
// L is running. Switching to a:
// s1 cost: L is NOT in s1, so 30
// s2 cost: L is in s2, so 0
// => pick s2
result, err = solver.Solve("a", []string{"L"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
}
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
// Two sets with identical cost. Definition order should win.
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "x"),
es("s2", "a", "y"),
),
nil,
)
// Nothing running, both sets cost 0. s1 is first.
result, err := solver.Solve("a", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
}
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
// Model "v" costs 50 to evict, "m" costs 1 (default).
// Sets: [g,v], [g,m]
// Running: v, m. Request g.
// s1=[g,v]: evict m (cost 1), keep v
// s2=[g,m]: evict v (cost 50), keep m
// => pick s1
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "g", "m"),
),
map[string]int{"v": 50},
)
result, err := solver.Solve("g", []string{"v", "m"})
require.NoError(t, err)
assert.Equal(t, []string{"m"}, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
func TestMatrixSolver_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "q", "v"),
),
nil,
)
result, err := solver.Solve("g", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
func TestMatrixSolver_FullScenario(t *testing.T) {
// Simulates the example config:
// standard: [g,v], [q,v], [m,v]
// with_rerank: [g,v,e], [q,v,e]
// creative: [g,sd], [q,sd]
// full: [L]
solver := NewMatrixSolver(
makeExpandedSets(
es("standard", "g", "v"),
es("standard", "q", "v"),
es("standard", "m", "v"),
es("with_rerank", "e", "g", "v"),
es("with_rerank", "e", "q", "v"),
es("creative", "g", "sd"),
es("creative", "q", "sd"),
es("full", "L"),
),
map[string]int{"v": 50, "L": 30, "whisper": 10},
)
// Running: g, v. Request q.
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
// => tie, pick first by definition order = standard[q,v]
result, err := solver.Solve("q", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"g"}, result.Evict)
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
// Running: g, v. Request L.
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
// Only one set contains L, so pick it.
result, err = solver.Solve("L", []string{"g", "v"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
assert.Equal(t, []string{"L"}, result.TargetSet)
// Running: g, v. Request sd.
// creative[g,sd]: evict v (cost 50). Total: 50.
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
// => pick creative[g,sd]
result, err = solver.Solve("sd", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"v"}, result.Evict)
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
// Running: q, v, e. Request g.
// standard[g,v]: evict q (1) + e (1). Total: 2.
// with_rerank[g,v,e]: evict q (1). Total: 1.
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
// => pick with_rerank[g,v,e]
result, err = solver.Solve("g", []string{"e", "q", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"q"}, result.Evict)
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
}
+99 -34
View File
@@ -77,6 +77,9 @@ type ProxyManager struct {
processGroups map[string]*ProcessGroup
// matrix-based swap (mutually exclusive with processGroups)
matrix *Matrix
inFlightCounter *InflightCounter
// shutdown signaling
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
peerProxy: peerProxy,
}
// create the process groups
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
// create either matrix or process groups (mutually exclusive)
if proxyConfig.Matrix != nil {
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
} else {
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
}
pm.setupGinEngine()
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
}
proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
var preloadErr error
req, _ := http.NewRequest("GET", "/", nil)
if pm.matrix != nil {
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
preloadErr = err
} else {
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
}
}
if preloadErr != nil {
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: false,
})
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
continue
} else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: modelID,
Success: true,
@@ -453,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()
if pm.matrix != nil {
pm.matrix.StopProcesses(strategy)
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, processGroup := range pm.processGroups {
@@ -473,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
if pm.matrix != nil {
pm.matrix.Shutdown()
pm.shutdownCancel()
return
}
var wg sync.WaitGroup
// Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
@@ -639,10 +668,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
var handler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
handler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
handler = processGroup.ProxyRequest
}
// rewrite the path
@@ -651,13 +686,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
if err := handler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return
@@ -683,10 +718,16 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
@@ -737,7 +778,7 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -823,15 +864,19 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
@@ -942,14 +987,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
var modelID string
if realModelID, found := pm.config.RealModelName(requestedModel); found {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
modelID = realModelID
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
modelID = requestedModel
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
@@ -1048,9 +1097,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
if pm.matrix != nil {
for _, modelID := range pm.matrix.RunningModels() {
if process, ok := pm.matrix.GetProcess(modelID); ok {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
@@ -1062,6 +1111,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
})
}
}
} else {
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
"cmd": process.config.Cmd,
"proxy": process.config.Proxy,
"ttl": process.config.UnloadAfter,
"name": process.config.Name,
"description": process.config.Description,
})
}
}
}
}
// Put the results under the `running` key.
+34 -28
View File
@@ -55,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
// Iterate over sorted keys
for _, modelID := range modelIDs {
// Get process state
processGroup := pm.findGroupByModelName(modelID)
state := "unknown"
if processGroup != nil {
process := processGroup.processes[modelID]
if process != nil {
var stateStr string
switch process.CurrentState() {
case StateReady:
stateStr = "ready"
case StateStarting:
stateStr = "starting"
case StateStopping:
stateStr = "stopping"
case StateShutdown:
stateStr = "shutdown"
case StateStopped:
stateStr = "stopped"
default:
stateStr = "unknown"
}
state = stateStr
var process *Process
if pm.matrix != nil {
process, _ = pm.matrix.GetProcess(modelID)
} else {
processGroup := pm.findGroupByModelName(modelID)
if processGroup != nil {
process = processGroup.processes[modelID]
}
}
if process != nil {
switch process.CurrentState() {
case StateReady:
state = "ready"
case StateStarting:
state = "starting"
case StateStopping:
state = "stopping"
case StateShutdown:
state = "shutdown"
case StateStopped:
state = "stopped"
}
}
models = append(models, Model{
@@ -254,18 +255,23 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
return
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
var stopErr error
if pm.matrix != nil {
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
} else {
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
}
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
if stopErr != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
return
} else {
c.String(http.StatusOK, "OK")
}
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {