fix: dereference orchestra-skills submodule, add as plain files

This commit is contained in:
2026-05-05 23:28:24 +02:00
parent 964c1dacc9
commit b275af2b4d
530 changed files with 221754 additions and 1 deletions
Vendored
BIN
View File
Binary file not shown.
BIN
View File
Binary file not shown.
Vendored Executable
BIN
View File
Binary file not shown.
Submodule orchestra-skills deleted from 28f2d29236
@@ -0,0 +1,293 @@
{
"name": "ai-research-skills",
"owner": {
"name": "Orchestra Research",
"email": "zechen@orchestra-research.com"
},
"metadata": {
"description": "Comprehensive library of 98 AI research engineering skills enabling autonomous AI research from hypothesis to experimental verification",
"version": "1.2.0"
},
"plugins": [
{
"name": "model-architecture",
"description": "LLM architectures and implementations including LitGPT, Mamba, NanoGPT, RWKV, and TorchTitan. Use when implementing, training, or understanding transformer and alternative architectures.",
"source": "./",
"strict": false,
"skills": [
"./01-model-architecture/litgpt",
"./01-model-architecture/mamba",
"./01-model-architecture/nanogpt",
"./01-model-architecture/rwkv",
"./01-model-architecture/torchtitan"
]
},
{
"name": "tokenization",
"description": "Text tokenization for LLMs including HuggingFace Tokenizers and SentencePiece. Use when training custom tokenizers or handling multilingual text.",
"source": "./",
"strict": false,
"skills": [
"./02-tokenization/huggingface-tokenizers",
"./02-tokenization/sentencepiece"
]
},
{
"name": "fine-tuning",
"description": "LLM fine-tuning frameworks including Axolotl, LLaMA-Factory, PEFT, and Unsloth. Use when fine-tuning models with LoRA, QLoRA, or full fine-tuning.",
"source": "./",
"strict": false,
"skills": [
"./03-fine-tuning/axolotl",
"./03-fine-tuning/llama-factory",
"./03-fine-tuning/peft",
"./03-fine-tuning/unsloth"
]
},
{
"name": "mechanistic-interpretability",
"description": "Neural network interpretability tools including TransformerLens, SAELens, NNSight, and pyvene. Use when analyzing model internals, finding circuits, or understanding how models compute.",
"source": "./",
"strict": false,
"skills": [
"./04-mechanistic-interpretability/nnsight",
"./04-mechanistic-interpretability/pyvene",
"./04-mechanistic-interpretability/saelens",
"./04-mechanistic-interpretability/transformer-lens"
]
},
{
"name": "data-processing",
"description": "Data curation and processing at scale including NeMo Curator and Ray Data. Use when preparing training datasets or processing large-scale data.",
"source": "./",
"strict": false,
"skills": [
"./05-data-processing/nemo-curator",
"./05-data-processing/ray-data"
]
},
{
"name": "post-training",
"description": "RLHF and preference alignment including TRL, GRPO, OpenRLHF, SimPO, verl, slime, miles, and torchforge. Use when aligning models with human preferences, training reward models, or large-scale RL training.",
"source": "./",
"strict": false,
"skills": [
"./06-post-training/grpo-rl-training",
"./06-post-training/miles",
"./06-post-training/openrlhf",
"./06-post-training/simpo",
"./06-post-training/slime",
"./06-post-training/torchforge",
"./06-post-training/trl-fine-tuning",
"./06-post-training/verl"
]
},
{
"name": "safety-alignment",
"description": "AI safety and content moderation including Constitutional AI, LlamaGuard, NeMo Guardrails, and Prompt Guard. Use when implementing safety filters, content moderation, or prompt injection detection.",
"source": "./",
"strict": false,
"skills": [
"./07-safety-alignment/constitutional-ai",
"./07-safety-alignment/llamaguard",
"./07-safety-alignment/nemo-guardrails",
"./07-safety-alignment/prompt-guard"
]
},
{
"name": "distributed-training",
"description": "Multi-GPU and multi-node training including DeepSpeed, PyTorch FSDP, Accelerate, Megatron-Core, PyTorch Lightning, and Ray Train. Use when training large models across GPUs.",
"source": "./",
"strict": false,
"skills": [
"./08-distributed-training/accelerate",
"./08-distributed-training/deepspeed",
"./08-distributed-training/megatron-core",
"./08-distributed-training/pytorch-fsdp2",
"./08-distributed-training/pytorch-lightning",
"./08-distributed-training/ray-train"
]
},
{
"name": "infrastructure",
"description": "GPU cloud and compute orchestration including Modal, Lambda Labs, and SkyPilot. Use when deploying training jobs or managing GPU resources.",
"source": "./",
"strict": false,
"skills": [
"./09-infrastructure/lambda-labs",
"./09-infrastructure/modal",
"./09-infrastructure/skypilot"
]
},
{
"name": "optimization",
"description": "Model optimization and quantization including Flash Attention, bitsandbytes, GPTQ, AWQ, GGUF, and HQQ, plus ML training recipes. Use when reducing memory, accelerating inference, or quantizing models.",
"source": "./",
"strict": false,
"skills": [
"./10-optimization/awq",
"./10-optimization/bitsandbytes",
"./10-optimization/flash-attention",
"./10-optimization/gguf",
"./10-optimization/gptq",
"./10-optimization/hqq",
"./10-optimization/ml-training-recipes"
]
},
{
"name": "evaluation",
"description": "LLM benchmarking and evaluation including lm-evaluation-harness, BigCode Evaluation Harness, and NeMo Evaluator. Use when benchmarking models or measuring performance.",
"source": "./",
"strict": false,
"skills": [
"./11-evaluation/bigcode-evaluation-harness",
"./11-evaluation/lm-evaluation-harness",
"./11-evaluation/nemo-evaluator"
]
},
{
"name": "inference-serving",
"description": "Production LLM inference including vLLM, TensorRT-LLM, llama.cpp, and SGLang. Use when deploying models for production inference.",
"source": "./",
"strict": false,
"skills": [
"./12-inference-serving/llama-cpp",
"./12-inference-serving/sglang",
"./12-inference-serving/tensorrt-llm",
"./12-inference-serving/vllm"
]
},
{
"name": "mlops",
"description": "ML experiment tracking and lifecycle including Weights & Biases, MLflow, and TensorBoard. Use when tracking experiments or managing models.",
"source": "./",
"strict": false,
"skills": [
"./13-mlops/mlflow",
"./13-mlops/tensorboard",
"./13-mlops/weights-and-biases"
]
},
{
"name": "agents",
"description": "LLM agent frameworks including LangChain, LlamaIndex, CrewAI, and AutoGPT. Use when building chatbots, autonomous agents, or tool-using systems.",
"source": "./",
"strict": false,
"skills": [
"./14-agents/autogpt",
"./14-agents/crewai",
"./14-agents/langchain",
"./14-agents/llamaindex"
]
},
{
"name": "rag",
"description": "Retrieval-Augmented Generation including Chroma, FAISS, Pinecone, Qdrant, and Sentence Transformers. Use when building semantic search or document retrieval systems.",
"source": "./",
"strict": false,
"skills": [
"./15-rag/chroma",
"./15-rag/faiss",
"./15-rag/pinecone",
"./15-rag/qdrant",
"./15-rag/sentence-transformers"
]
},
{
"name": "prompt-engineering",
"description": "Structured LLM outputs including DSPy, Instructor, Guidance, and Outlines. Use when extracting structured data or constraining LLM outputs.",
"source": "./",
"strict": false,
"skills": [
"./16-prompt-engineering/dspy",
"./16-prompt-engineering/guidance",
"./16-prompt-engineering/instructor",
"./16-prompt-engineering/outlines"
]
},
{
"name": "observability",
"description": "LLM application monitoring including LangSmith and Phoenix. Use when debugging LLM apps or monitoring production systems.",
"source": "./",
"strict": false,
"skills": [
"./17-observability/langsmith",
"./17-observability/phoenix"
]
},
{
"name": "multimodal",
"description": "Vision, audio, and multimodal models including CLIP, Whisper, LLaVA, BLIP-2, Segment Anything, Stable Diffusion, AudioCraft, Cosmos Policy, OpenPI, and OpenVLA-OFT. Use when working with images, audio, multimodal tasks, or vision-language-action robot policies.",
"source": "./",
"strict": false,
"skills": [
"./18-multimodal/audiocraft",
"./18-multimodal/blip-2",
"./18-multimodal/clip",
"./18-multimodal/cosmos-policy",
"./18-multimodal/llava",
"./18-multimodal/openpi",
"./18-multimodal/openvla-oft",
"./18-multimodal/segment-anything",
"./18-multimodal/stable-diffusion",
"./18-multimodal/whisper"
]
},
{
"name": "emerging-techniques",
"description": "Advanced ML techniques including MoE Training, Model Merging, Long Context, Speculative Decoding, Knowledge Distillation, and Model Pruning. Use when implementing cutting-edge optimization or architecture techniques.",
"source": "./",
"strict": false,
"skills": [
"./19-emerging-techniques/knowledge-distillation",
"./19-emerging-techniques/long-context",
"./19-emerging-techniques/model-merging",
"./19-emerging-techniques/model-pruning",
"./19-emerging-techniques/moe-training",
"./19-emerging-techniques/speculative-decoding"
]
},
{
"name": "autoresearch",
"description": "Autonomous research orchestration using a two-loop architecture. Manages the full research lifecycle from literature survey to paper writing, routing to domain-specific skills for execution. Use when starting a research project, running autonomous experiments, or managing multi-hypothesis research.",
"source": "./",
"strict": false,
"skills": [
"./0-autoresearch-skill"
]
},
{
"name": "ml-paper-writing",
"description": "Write publication-ready ML/AI/Systems papers for NeurIPS, ICML, ICLR, ACL, AAAI, COLM, OSDI, NSDI, ASPLOS, SOSP. Includes LaTeX templates, citation verification, reviewer guidelines, publication-quality figure generation, systems paper structural blueprints, and conference presentation slides.",
"source": "./",
"strict": false,
"skills": [
"./20-ml-paper-writing/ml-paper-writing",
"./20-ml-paper-writing/academic-plotting",
"./20-ml-paper-writing/systems-paper-writing",
"./20-ml-paper-writing/presenting-conference-talks"
]
},
{
"name": "ideation",
"description": "Research ideation frameworks including structured brainstorming and creative thinking. Use when exploring new research directions, generating novel ideas, or seeking fresh angles on existing work.",
"source": "./",
"strict": false,
"skills": [
"./21-research-ideation/brainstorming-research-ideas",
"./21-research-ideation/creative-thinking-for-research"
]
},
{
"name": "agent-native-research-artifact",
"description": "Agent-Native Research Artifact (ARA) tooling: compile any research input (paper, repo, notes) into a structured artifact, record session provenance as a post-task epilogue, and run Seal Level 2 epistemic review. Use when ingesting research into a falsifiable, agent-traversable artifact, capturing how a research project actually evolved, or auditing an ARA for evidence-claim alignment.",
"source": "./",
"strict": false,
"skills": [
"./22-agent-native-research-artifact/compiler",
"./22-agent-native-research-artifact/research-manager",
"./22-agent-native-research-artifact/rigor-reviewer"
]
}
]
}
+27
View File
@@ -0,0 +1,27 @@
# Claude Code assistant workflow.
# Runs the Claude Code action when "@claude" appears in an issue comment, a
# pull-request review comment, or the body of an opened/assigned issue, and
# only when the author's association is OWNER, MEMBER, or COLLABORATOR.
name: Claude Code

on:
  issue_comment:
    types:
      - created
  pull_request_review_comment:
    types:
      - created
  issues:
    types:
      - opened
      - assigned

# Scopes granted to the workflow's GITHUB_TOKEN.
permissions:
  contents: write
  pull-requests: write
  issues: write

jobs:
  claude:
    # Both comment events expose the text and author under github.event.comment,
    # so they share one clause; issue events read github.event.issue instead.
    if: |
      (
        (github.event_name == 'issue_comment' || github.event_name == 'pull_request_review_comment') &&
        contains(github.event.comment.body, '@claude') &&
        contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association)
      ) ||
      (
        github.event_name == 'issues' &&
        contains(github.event.issue.body, '@claude') &&
        contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.issue.author_association)
      )
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: anthropics/claude-code-action@v1
        with:
          claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
          github_token: ${{ secrets.GITHUB_TOKEN }}
+73
View File
@@ -0,0 +1,73 @@
# Publishes packages/ai-research-skills to npm when a push to main changes the
# version in its package.json and that version is not yet on the registry.
name: Publish to npm

on:
  push:
    branches: [main]
    paths:
      - 'packages/ai-research-skills/**'

permissions:
  id-token: write # OIDC token used by `npm publish --provenance`
  contents: read

jobs:
  publish:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: packages/ai-research-skills

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          # Full history, so the commit main pointed at before this push is
          # available no matter how many commits the push contained.
          fetch-depth: 0

      - name: Check if version changed
        id: version
        env:
          # All zeros when the branch was just created.
          BEFORE_SHA: ${{ github.event.before }}
        run: |
          CURRENT=$(node -p "require('./package.json').version")

          # Compare with the pre-push state of the branch, not HEAD~1: in a
          # multi-commit push the version bump may not be in the tip commit.
          # Fall back to HEAD~1 when the pre-push commit is unknown or missing.
          BASE="$BEFORE_SHA"
          if [ -z "$BASE" ] || ! git cat-file -e "${BASE}^{commit}" 2>/dev/null; then
            BASE="HEAD~1"
          fi

          # An unreadable previous package.json yields "", which counts as changed.
          PREVIOUS=$(git show "$BASE:packages/ai-research-skills/package.json" 2>/dev/null | node -p "JSON.parse(require('fs').readFileSync('/dev/stdin','utf8')).version" 2>/dev/null || echo "")

          echo "current=$CURRENT"
          echo "previous=$PREVIOUS"

          if [ "$CURRENT" != "$PREVIOUS" ]; then
            echo "changed=true" >> "$GITHUB_OUTPUT"
            echo "version=$CURRENT" >> "$GITHUB_OUTPUT"
          else
            echo "changed=false" >> "$GITHUB_OUTPUT"
          fi

      - name: Check if version already published
        if: steps.version.outputs.changed == 'true'
        id: published
        env:
          # Passed through the environment instead of being pasted into the
          # script text, so the value is never parsed as shell code.
          VERSION: ${{ steps.version.outputs.version }}
        run: |
          # Judge by the printed version rather than the exit status alone, so
          # an empty result is never mistaken for "already published".
          PUBLISHED=$(npm view "@orchestra-research/ai-research-skills@$VERSION" version 2>/dev/null || true)
          if [ -n "$PUBLISHED" ]; then
            echo "already_published=true" >> "$GITHUB_OUTPUT"
            echo "Version $VERSION already on npm, skipping"
          else
            echo "already_published=false" >> "$GITHUB_OUTPUT"
          fi

      - name: Setup Node.js
        if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
        uses: actions/setup-node@v4
        with:
          node-version: '24'
          registry-url: 'https://registry.npmjs.org'

      - name: Install dependencies
        if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
        run: npm ci

      - name: Publish to npm
        if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
        env:
          VERSION: ${{ steps.version.outputs.version }}
        run: |
          echo "Publishing v$VERSION to npm..."
          # Remove any token-based credentials before publishing.
          # NOTE(review): presumably so npm authenticates via the OIDC id-token
          # instead of a stored token — confirm against the npm package settings.
          unset NODE_AUTH_TOKEN
          npm config delete //registry.npmjs.org/:_authToken || true
          npm publish --access public --provenance

      - name: Skip reason
        if: steps.version.outputs.changed != 'true'
        run: echo "Version unchanged, skipping publish"
+199
View File
@@ -0,0 +1,199 @@
# On each push to main (or a manual run), finds the skill directories touched
# by the push, zips each one, and POSTs it to the Orchestra sync endpoint.
name: Sync Skills to Orchestra

on:
  push:
    branches:
      - main
  workflow_dispatch: # Allow manual trigger

jobs:
  sync-skills:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          # Full history: the diff base is the commit main pointed at before
          # the push, which can be more than one commit behind HEAD.
          fetch-depth: 0

      - name: Ensure jq and zip are available
        run: |
          # Both later steps call jq, so the tools must exist before either runs.
          if ! command -v jq >/dev/null 2>&1 || ! command -v zip >/dev/null 2>&1; then
            sudo apt-get update && sudo apt-get install -y jq zip
          fi

      - name: Detect changed skill folders
        id: changes
        env:
          # Empty for workflow_dispatch; all zeros when the branch was just created.
          BEFORE_SHA: ${{ github.event.before }}
        run: |
          # Diff the whole push, not only its tip commit, so skills changed in
          # earlier commits of a multi-commit push are not missed. Fall back to
          # the tip commit's parent when the pre-push commit is unknown.
          BASE="$BEFORE_SHA"
          if [ -z "$BASE" ] || ! git cat-file -e "${BASE}^{commit}" 2>/dev/null; then
            BASE="HEAD^"
          fi
          CHANGED_FILES=$(git diff --name-only "$BASE" HEAD)

          echo "Changed files:"
          echo "$CHANGED_FILES"

          # Find skill directories - supports two patterns:
          # Pattern 1: XX-category/skill-name/SKILL.md (nested skills)
          # Pattern 2: XX-category/SKILL.md (standalone skills)
          #
          # NOTE(review): both patterns require a two-digit prefix, so paths
          # under 0-autoresearch-skill/ never match. Files inside a standalone
          # skill's subfolders match Pattern 1 and are later skipped for
          # lacking SKILL.md. Confirm whether either case should sync.
          SKILL_DIRS=""

          # Pattern 1: Nested skills (XX-category/skill-name/)
          NESTED=$(echo "$CHANGED_FILES" | grep -E '^[0-9]{2}-[^/]+/[^/]+/' | sed -E 's|^([0-9]{2}-[^/]+/[^/]+)/.*|\1|' | sort -u)
          if [ -n "$NESTED" ]; then
            SKILL_DIRS="$NESTED"
          fi

          # Pattern 2: Standalone skills (XX-category/ with SKILL.md directly inside)
          STANDALONE=$(echo "$CHANGED_FILES" | grep -E '^[0-9]{2}-[^/]+/SKILL\.md$' | sed -E 's|^([0-9]{2}-[^/]+)/SKILL\.md$|\1|' | sort -u)
          if [ -n "$STANDALONE" ]; then
            if [ -n "$SKILL_DIRS" ]; then
              SKILL_DIRS=$(printf "%s\n%s" "$SKILL_DIRS" "$STANDALONE" | sort -u)
            else
              SKILL_DIRS="$STANDALONE"
            fi
          fi

          echo "Changed skill directories:"
          echo "$SKILL_DIRS"

          # Convert to a JSON array and take the count from that same array
          if [ -z "$SKILL_DIRS" ]; then
            SKILLS_JSON="[]"
          else
            SKILLS_JSON=$(echo "$SKILL_DIRS" | jq -R -s -c 'split("\n") | map(select(length > 0))')
          fi
          SKILL_COUNT=$(echo "$SKILLS_JSON" | jq 'length')

          echo "skills=$SKILLS_JSON" >> "$GITHUB_OUTPUT"
          echo "count=$SKILL_COUNT" >> "$GITHUB_OUTPUT"

      - name: Process and sync skills
        if: steps.changes.outputs.count > 0
        env:
          ORCHESTRA_API_URL: ${{ secrets.ORCHESTRA_API_URL }}
          ORCHESTRA_SYNC_API_KEY: ${{ secrets.ORCHESTRA_SYNC_API_KEY }}
          # Passed through the environment instead of being pasted into the
          # script text, so a quote in a directory name cannot alter the script.
          SKILLS: ${{ steps.changes.outputs.skills }}
        run: |
          echo "Processing $(echo "$SKILLS" | jq 'length') skill(s)..."

          # The loop reads from process substitution, so it runs in the main
          # shell and the `exit 1` below fails the step directly.
          while IFS= read -r SKILL_PATH; do
            echo "==================================================="
            echo "Processing: $SKILL_PATH"
            echo "==================================================="

            # Check if SKILL.md exists
            if [ ! -f "$SKILL_PATH/SKILL.md" ]; then
              echo "⚠️ WARNING: No SKILL.md found in $SKILL_PATH, skipping"
              continue
            fi

            # Extract skill name from SKILL.md frontmatter
            SKILL_NAME=$(grep -A 20 "^---$" "$SKILL_PATH/SKILL.md" | grep "^name:" | head -1 | sed 's/name: *//;s/"//g;s/'\''//g' | tr -d '\r')

            # Extract author from SKILL.md frontmatter
            AUTHOR=$(grep -A 20 "^---$" "$SKILL_PATH/SKILL.md" | grep "^author:" | head -1 | sed 's/author: *//;s/"//g;s/'\''//g' | tr -d '\r')

            # Default values
            if [ -z "$SKILL_NAME" ]; then
              # Extract from directory name as fallback
              SKILL_NAME=$(basename "$SKILL_PATH")
              echo "⚠️ No 'name' in frontmatter, using directory name: $SKILL_NAME"
            fi
            if [ -z "$AUTHOR" ]; then
              AUTHOR="Orchestra Research"
              echo "⚠️ No 'author' in frontmatter, defaulting to: $AUTHOR"
            fi

            echo "Skill Name: $SKILL_NAME"
            echo "Author: $AUTHOR"
            echo "Path: $SKILL_PATH"

            # Create temporary directory for zipping
            TEMP_DIR=$(mktemp -d)
            SKILL_DIR="$TEMP_DIR/$SKILL_NAME"
            mkdir -p "$SKILL_DIR"

            # Copy all contents of skill directory (SKILL.md, references/, scripts/, assets/, etc.)
            cp -r "$SKILL_PATH"/* "$SKILL_DIR/" 2>/dev/null || true

            # Create zip file (exclude hidden files and .gitkeep); the subshell
            # keeps the directory change from leaking into the rest of the loop
            ZIP_FILE="$TEMP_DIR/${SKILL_NAME}.zip"
            (cd "$TEMP_DIR" && zip -r "$ZIP_FILE" "$SKILL_NAME" -x "*/.*" "*/.gitkeep" "*.DS_Store")

            # Verify zip was created
            if [ ! -f "$ZIP_FILE" ]; then
              echo "❌ ERROR: Failed to create zip file for $SKILL_NAME"
              continue
            fi

            echo "✓ Created zip: $(ls -lh "$ZIP_FILE" | awk '{print $5}')"

            # Write SKILL.md content to temp file (avoid argument length limits)
            SKILL_MD_FILE="$TEMP_DIR/skill.md"
            cat "$SKILL_PATH/SKILL.md" > "$SKILL_MD_FILE"

            # Encode zip to base64 and write to temp file (avoid argument length limits)
            ZIP_BASE64_FILE="$TEMP_DIR/base64.txt"
            base64 -w 0 "$ZIP_FILE" > "$ZIP_BASE64_FILE" 2>/dev/null || base64 "$ZIP_FILE" > "$ZIP_BASE64_FILE"

            # Build the JSON payload straight into a file (--rawfile for large content)
            JSON_FILE="$TEMP_DIR/payload.json"
            jq -n \
              --arg skillName "$SKILL_NAME" \
              --arg skillPath "$SKILL_PATH" \
              --arg author "$AUTHOR" \
              --rawfile skillMdContent "$SKILL_MD_FILE" \
              --rawfile zipBase64 "$ZIP_BASE64_FILE" \
              '{
                skillName: $skillName,
                skillPath: $skillPath,
                author: $author,
                skillMdContent: $skillMdContent,
                zipBase64: $zipBase64
              }' > "$JSON_FILE"

            # Send to Orchestra API; the status code is appended as the last line
            echo "📤 Uploading to Orchestra..."
            RESPONSE=$(curl -s -w "\n%{http_code}" -L \
              -X POST \
              -H "Content-Type: application/json" \
              -H "X-Admin-API-Key: $ORCHESTRA_SYNC_API_KEY" \
              -d @"$JSON_FILE" \
              "$ORCHESTRA_API_URL/api/admin/sync-github-skill")

            HTTP_CODE=$(echo "$RESPONSE" | tail -n1)
            BODY=$(echo "$RESPONSE" | sed '$d')

            echo "HTTP Status: $HTTP_CODE"
            echo "Response: $BODY"

            if [ "$HTTP_CODE" = "200" ]; then
              ACTION=$(echo "$BODY" | jq -r '.action // "synced"')
              SOURCE=$(echo "$BODY" | jq -r '.source // "unknown"')
              echo "✅ SUCCESS: Skill $SKILL_NAME $ACTION (source: $SOURCE)"
            else
              ERROR_MSG=$(echo "$BODY" | jq -r '.error // "Unknown error"')
              echo "❌ FAILED: $ERROR_MSG"
              rm -rf "$TEMP_DIR"
              exit 1
            fi

            # Cleanup
            rm -rf "$TEMP_DIR"
            echo ""
          done < <(echo "$SKILLS" | jq -r '.[]')

          echo "==================================================="
          echo "✅ Sync completed successfully!"
          echo "==================================================="

      - name: No changes detected
        if: steps.changes.outputs.count == 0
        run: |
          echo "ℹ️ No skill changes detected in this push"
          echo "Only pushes that modify skill directories will trigger sync"
+103
View File
@@ -0,0 +1,103 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
# LaTeX auxiliary files
*.aux
*.bbl
*.blg
*.out
*.fls
*.fdb_latexmk
*.synctex.gz
*.toc
*.lof
*.lot
*.nav
*.snm
# Python distribution / packaging and installer logs
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
*.manifest
*.spec
pip-log.txt
pip-delete-this-directory.txt
# Virtual environments
venv/
ENV/
env/
.venv
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
# Jupyter Notebook
.ipynb_checkpoints
*.ipynb
# Pytest
.pytest_cache/
.coverage
htmlcov/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# ML/Data
*.h5
*.pkl
*.pth
*.ckpt
*.safetensors
wandb/
runs/
outputs/
checkpoints/
*.log
# Environment variables
.env
.env.local
# Temporary files
tmp/
temp/
*.tmp
# Skill Seeker metadata and build artifacts
.metadata/
*_data/
# Re-include dev_data/, which the *_data/ pattern above would otherwise ignore;
# this negation must stay after that pattern to take effect.
!dev_data/
*_github_data.json
*_extracted.json
output/
*.zip
# Untracked working files under 0-autoresearch-skill/ (drafts, briefs, an image)
0-autoresearch-skill/background_docs/
0-autoresearch-skill/twitter_thread_draft.md
0-autoresearch-skill/social_posts.md
0-autoresearch-skill/image_generation_brief.md
0-autoresearch-skill/autoresearch-loops-image.png
@@ -0,0 +1,411 @@
---
name: autoresearch
description: Orchestrates end-to-end autonomous AI research projects using a two-loop architecture. The inner loop runs rapid experiment iterations with clear optimization targets. The outer loop synthesizes results, identifies patterns, and steers research direction. Routes to domain-specific skills for execution, supports continuous agent operation via Claude Code /loop and OpenClaw heartbeat, and produces research presentations and papers. Use when starting a research project, running autonomous experiments, or managing a multi-hypothesis research effort.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Autonomous Research, Two-Loop Architecture, Experiment Orchestration, Research Synthesis, Project Management]
---
# Autoresearch
Autonomous research orchestration for AI coding agents. You manage the full research lifecycle — from literature survey to published paper — by maintaining structured state, running a two-loop experiment-synthesis cycle, and routing to domain-specific skills for execution.
You are a research project manager, not a domain expert. You orchestrate; the domain skills execute.
**This runs fully autonomously.** Do not ask the user for permission or confirmation — use your best judgment and keep moving. Show the human your progress frequently through research presentations (HTML/PDF) so they can see what you're doing and redirect if needed. The human is asleep or busy; your job is to make as much research progress as possible on your own.
## Getting Started
Users arrive in different states. Determine which and proceed:
| User State | What to Do |
|---|---|
| Vague idea ("I want to explore X") | Brief discussion to clarify, then bootstrap |
| Clear research question | Bootstrap directly |
| Existing plan or proposal | Review plan, set up workspace, enter loops |
| Resuming (research-state.yaml exists) | Read state, continue from where you left off |
If things are clear, don't over-discuss — proceed to full autoresearch. Most users want you to just start researching.
**Step 0 — before anything else**: Set up the agent continuity loop. See [Agent Continuity](#agent-continuity-mandatory--set-up-first). This is MANDATORY. Without it, the research stops after one cycle.
### Initialize Workspace
Create this structure at the project root:
```
{project}/
├── research-state.yaml          # Central state tracking
├── research-log.md              # Decision timeline
├── findings.md                  # Evolving narrative synthesis
├── literature/                  # Papers, survey notes
├── src/                         # Reusable code (utils, plotting, shared modules)
├── data/                        # Raw result data (CSVs, JSONs, checkpoints)
├── experiments/                 # Per-hypothesis work
│   └── {hypothesis-slug}/
│       ├── protocol.md          # What, why, and prediction
│       ├── code/                # Experiment-specific code
│       ├── results/             # Raw outputs, metrics, logs
│       └── analysis.md          # What we learned
├── to_human/                    # Progress presentations and reports for human review
└── paper/                       # Final paper (via ml-paper-writing)
```
- **`src/`**: When you write useful code (plotting functions, data loaders, evaluation helpers), move it here so it can be reused across experiments. Don't duplicate code in every experiment directory.
- **`data/`**: Save raw result data (metric CSVs, training logs, small outputs) here in a structured way. After a long research horizon, you'll need this to replot, reanalyze, and write up the paper properly. Name files descriptively (e.g., `trajectory_H1_runs001-010.csv`). Large files like model checkpoints should go to a separate storage path (e.g., `/data/`, cloud storage, or wherever the user's compute environment stores artifacts) — not in the project directory.
Initialize `research-state.yaml`, `research-log.md`, and `findings.md` from [templates/](templates/). Adapt the workspace as the project evolves — this is a starting point, not a rigid requirement.
## The Two-Loop Architecture
This is the core engine. Everything else supports it.
```
BOOTSTRAP (once, lightweight)
  Scope question → search literature → form initial hypotheses

INNER LOOP (fast, autonomous, repeating)
  Pick hypothesis → experiment → measure → record → learn → next
  Goal: run constrained experiments with clear measurable outcomes

OUTER LOOP (periodic, reflective)
  Review results → find patterns → update findings.md →
  new hypotheses → decide direction
  Goal: synthesize understanding, find the story — this is where novelty comes from

FINALIZE (when concluding)
  Write paper via ml-paper-writing → final presentation → archive
```
The inner loop runs tight experiment cycles with clear measurable outcomes. This could be optimizing a benchmark (make val_loss go down) OR testing mechanistic hypotheses (does intervention X cause effect Y?). The outer loop steps back to ask: what do these results *mean*? What patterns emerge? What's the story? Research is open-ended — the two loops let you both optimize and discover.
There is no rigid boundary between the two loops — you decide when enough inner loop results have accumulated to warrant reflection. Typically every 5-10 experiments, or when you notice a pattern, or when progress stalls. The agent's judgment drives the rhythm.
### Research is Non-Linear
The two-loop structure is a rhythm, not a railroad. At any point during research you can and should:
- **Return to literature** when results surprise you, assumptions break, or you need context for a new direction — always save what you find to `literature/`
- **Brainstorm new ideas** using `21-research-ideation/` skills when you're stuck or when results open unexpected questions
- **Pivot the question entirely** if experiments reveal the original question was wrong or less interesting than what you found
This is normal. Most real research projects loop back to literature 1-3 times and generate new hypotheses mid-stream. Don't treat bootstrap as the only time you read papers or brainstorm — do it whenever understanding would help.
## Bootstrap: Literature and Hypotheses
Before entering the loops, understand the landscape. Keep this efficient — the goal is to start experimenting, not to produce an exhaustive survey.
1. **Search literature** for the research question. Use multiple sources — never stop at one:
- **Exa MCP** (`web_search_exa`) if available — best for broad discovery and finding relevant papers quickly
- **Semantic Scholar** (`pip install semanticscholar`) — best for ML/AI papers, citation graphs, and specific paper lookup. See `20-ml-paper-writing` skill's `references/citation-workflow.md` for complete API code examples
- **arXiv** (`pip install arxiv`) — best for recent preprints and open-access papers
- **CrossRef** — best for DOI lookup and BibTeX retrieval
- Keep searching until you have good coverage. If one source comes up empty, try another with different keywords
**Save everything to `literature/`**: For every paper you find, save a summary to `literature/` — title, authors, year, key findings, relevance to your question, and the URL/DOI. Create one file per paper and a running `literature/survey.md` with all summaries. This is your reference library — you and future sessions will need it throughout the project.
2. **Identify gaps** from the literature
- What's been tried? What hasn't? Where do existing methods break?
- What do Discussion sections flag as future work?
3. **Form initial hypotheses** — invoke `21-research-ideation/` skills
- `brainstorming-research-ideas` for structured diverge-converge workflow
- `creative-thinking-for-research` for deeper cognitive frameworks
- Each hypothesis must be testable with a clear prediction
4. **Define the evaluation**
- Set the proxy metric and baseline before running experiments
- The metric should be computable quickly (minutes, not hours)
- Lock evaluation criteria upfront to prevent unconscious metric gaming
5. **Record** in research-state.yaml, log the bootstrap in research-log.md
## The Inner Loop
Rapid iteration with clear measurable outcomes. Two flavors:
- **Optimization**: make a metric go up/down (val_loss, accuracy, throughput). Think Karpathy's autoresearch.
- **Discovery**: test mechanistic hypotheses about why something works. The metric is a measurement (does grokking happen faster? does entropy increase before forgetting?), not just a target to optimize.
```
1. Pick the highest-priority untested hypothesis
2. Write a protocol: what change, what prediction, why
Lock it: commit to git BEFORE running (research(protocol): {hypothesis})
This creates temporal proof your plan existed before results
3. Run the experiment (invoke the relevant domain skill)
4. Sanity check before trusting results:
- Did training converge? No NaN/Inf?
- Does baseline reproduce expected performance?
- Data loading correct? (spot-check a few samples)
5. Measure the proxy metric
6. Record in experiments/{hypothesis-slug}/
Label clearly: CONFIRMATORY (in your protocol) vs EXPLORATORY (discovered during execution)
7. If positive: keep, note WHY it worked
8. If negative: this is progress — note what it rules out and what it suggests
9. Update research-state.yaml
10. If stuck: search literature or invoke ideation skills — don't just keep trying random things
```
**Never stop.** Even if something fails, find a path forward. Debug, adjust, simplify, or pivot — but keep the research moving. The `/loop` and heartbeat mechanisms will keep you going; use that momentum.
### Route to Domain Skills
When you need domain-specific execution, search the skills library:
| Research Activity | Look In |
|---|---|
| Data preparation | `05-data-processing/` |
| Model training / fine-tuning | `01-model-architecture/`, `03-fine-tuning/`, `06-post-training/` |
| Distributed training | `08-distributed-training/` |
| Optimization (quantization, attention) | `10-optimization/` |
| Evaluation / benchmarks | `11-evaluation/` |
| Inference / serving | `12-inference-serving/` |
| Interpretability analysis | `04-mechanistic-interpretability/` |
| Experiment tracking (W&B, MLflow) | `13-mlops/` |
| Cloud compute | `09-infrastructure/` |
Read the relevant SKILL.md before starting — it has workflows, common issues, and code examples. See [references/skill-routing.md](references/skill-routing.md) for a complete guide.
### Track the Experiment Trajectory
Maintain a running record of measurable outcomes across experiments:
```json
{
"experiment_id": "run_014",
"hypothesis": "H3",
"metric_value": 0.847,
"baseline": 0.812,
"delta": "+0.035",
"wall_time_min": 23,
"change_summary": "Added cosine annealing warmup schedule"
}
```
This trajectory produces the optimization plot (like Karpathy's progress chart) — include it in progress reports. Humans love seeing the upward curve.
## The Outer Loop
Step back from individual experiments. Synthesize.
```
1. Review all results since last reflection
2. Cluster by type: what kinds of changes worked? Which didn't?
3. Ask WHY — identify the mechanism behind successes and failures
4. Update findings.md with current understanding
5. Search literature if results were surprising or assumptions need revisiting
6. Generate new hypotheses if warranted (invoke 21-research-ideation/ skills)
7. Decide direction (see criteria below)
8. Update research-state.yaml with new direction
9. Log the reflection in research-log.md
10. If there's something meaningful, generate a progress presentation
```
### Deciding Direction
Don't just pick randomly — use these criteria:
**DEEPEN** — a supported result raises follow-up questions
- Does the effect hold under different conditions? What's the mechanism?
- Action: generate sub-hypotheses (H1.1, H1.2) → back to inner loop
**BROADEN** — current results are solid, but adjacent questions are untested
- New questions emerged. The current contribution is clear but more is possible.
- Action: generate new root hypotheses → back to inner loop
**PIVOT** — results invalidate key assumptions or something more interesting appeared
- A core assumption was wrong, or an unexpected finding is more promising than the original question.
- Action: return to literature with new questions → re-bootstrap
**CONCLUDE** — sufficient evidence for a contribution
- At least one hypothesis is strongly supported (or a coherent set of negative results)
- Key ablations completed, error analysis done
- findings.md reads like a paper backbone — a human could write the abstract from it
- No critical open questions that would change the story
Note: coherent negative results are a valid contribution. "X does NOT work because Y" is publishable if the reasoning is rigorous.
### findings.md Is Your Project Memory
This file serves two purposes: it's the research narrative for humans AND your accumulated knowledge base as an agent. Read it at the start of every session, /loop tick, or heartbeat to remember what you've learned.
After every outer loop, update it to answer:
- What do we know so far? (Current Understanding)
- What patterns explain our results? (Patterns and Insights)
- What specific things did we learn not to repeat? (Lessons and Constraints)
- What remains open? (Open Questions)
The "Lessons and Constraints" section is especially important — it captures specific actionable learnings like "weight decay > 0.1 diverges at this scale" or "baseline only reproduces with batch_size=64." This prevents the agent from repeating failed approaches across sessions.
**Quality test**: After 30 inner loop experiments, a human should be able to read findings.md and write a paper abstract from it. If they can't, the outer loop isn't synthesizing — it's just logging.
## Agent Continuity (MANDATORY — Set Up First)
**Before doing anything else**, set up the wall-clock loop. This is what keeps the research running continuously. Without it, the agent stops after one cycle.
### Claude Code
**First action — run this immediately:**
```
/loop 20m Continue autoresearch. Read research-state.yaml and findings.md. Re-read the autoresearch SKILL.md occasionally to stay aligned. Step back and reflect holistically — is the research making real progress? Are you deepening understanding or just running experiments? If stalling, pivot or search literature for new ideas. Keep making research progress — never idle, never stop. Update findings.md, research-log.md, and research-state.yaml when there's new progress. Git commit periodically and clean up the repo if needed. Show the human your research progress with key plots and findings by preparing a report in to_human/ and opening the HTML/PDF. Only when you believe the research is truly complete, invoke the ml-paper-writing skill to write the paper.
```
This fires every 20 minutes regardless of what's happening. It's a rhythm that keeps you working — not a research phase boundary. If your previous work isn't done, just continue it.
### OpenClaw
**First action — set up a 20-minute cron job in the current session. You MUST do this before any research work.**
Use the `cron.add` tool to create a recurring job bound to this chat session:
```json
{
"name": "autoresearch-loop",
"schedule": { "kind": "every", "everyMs": 1200000 },
"sessionTarget": "current",
"payload": {
"kind": "agentTurn",
"message": "Continue autoresearch. Read research-state.yaml and findings.md. Re-read the autoresearch SKILL.md occasionally to stay aligned. Step back and reflect holistically — is the research making real progress? Are you deepening understanding or just running experiments? If stalling, pivot or search literature for new ideas. Keep making research progress — never idle, never stop. Update findings.md, research-log.md, and research-state.yaml when there's new progress. Git commit periodically and clean up the repo if needed. Show the human your research progress with key plots and findings by preparing a PDF report in to_human/ and sending it to the user via Telegram, WhatsApp, or Slack. When you get an exciting result or interesting plot, text it to the user right away — don't wait for a full report. Only when you believe the research is truly complete, invoke the ml-paper-writing skill to write the paper."
}
}
```
Key details:
- `sessionTarget: "current"` binds the cron to this chat session so it maintains conversation context across ticks
- `everyMs: 1200000` = 20 minutes
- After creation, verify with `cron.list` that the job exists and is enabled
- If the cron fires while you're mid-experiment, just continue — the tick is a nudge, not a restart
### What the Loop Does
The `/loop` and cron job are purely **wall-clock rhythm**. They are completely separate from your research loops (inner/outer). On each tick:
1. Read `research-state.yaml` and `findings.md` — remember where you are
2. Check if anything is broken (failed experiments, stalled training, errors)
3. If on track → keep working on whatever you were doing
4. If stuck or something's wrong → step back, diagnose, fix, then continue
5. Never idle. Always be making progress.
## Progress Reporting
When you have something meaningful to share, create a research presentation — not just a status dashboard, but a compelling story.
**When to report** (your judgment):
- After an outer loop that found a significant pattern
- When the optimization trajectory shows clear progress (include the plot!)
- After a pivot in direction
- Before requesting human input on a decision
- When concluding
**What to include** (adapt to what's compelling):
- The research question and why it matters
- Key results with visualizations (plots, metric tables)
- The optimization trajectory chart (metric over experiments)
- What was tried and why (selective, not exhaustive)
- Current understanding (the findings narrative)
- What's planned next
For Claude Code: generate HTML and `open` it. If HTML fails to open or render, convert to PDF as fallback (use `weasyprint`, `playwright pdf`, or `wkhtmltopdf`). For OpenClaw: generate PDF directly.
See [references/progress-reporting.md](references/progress-reporting.md) for template scaffolding and the optimization plot approach. Use the template as a starting point — be creative with what you show.
## Git Protocol
Commit at natural research milestones:
| When | Message Pattern |
|---|---|
| Workspace initialized | `research(init): {project} — {question}` |
| Experiment protocol locked | `research(protocol): {hypothesis}` |
| Significant results | `research(results): {hypothesis} — {outcome}` |
| Outer loop direction change | `research(reflect): {direction} — {reason}` |
| Paper draft complete | `research(paper): {title}` |
**Hard rule**: Protocol commits MUST precede result commits. Never combine them. The git history is your lightweight pre-registration — it proves what you planned before you saw results. Don't commit after every experiment — commit when there's meaningful progress.
## Concluding: Paper Writing
When the outer loop decides to CONCLUDE:
1. Ensure findings.md has a clear, well-supported narrative
2. Study 2-3 top related papers to learn their format, style, and section structure
3. Invoke the `20-ml-paper-writing` skill — it has LaTeX templates for NeurIPS, ICML, ICLR, ACL, AAAI, COLM, and systems venues
4. Feed it the accumulated literature, experimental results, and findings
5. Follow its citation verification workflow — never hallucinate references
6. Generate a final comprehensive research presentation
Proceed autonomously through the writing process. If the ml-paper-writing skill suggests human collaboration points, adapt and keep going — produce the best draft you can. The human will review and provide feedback.
## Research Discipline
Principles to enforce continuously — not tied to any specific phase:
- **Lock before you run**: Commit your experiment protocol to git before executing. This proves your plan existed before you saw results. Never combine protocol + results in one commit.
- **Confirmatory vs exploratory**: Results matching your locked protocol are confirmatory. Everything else is exploratory — interesting but requiring more skepticism.
- **Negative results are progress**: A refuted hypothesis tells you something. Log what it rules out and what it suggests. Don't treat it as failure.
- **Sanity check before analysis**: Verify training converged, baselines reproduce, and data is correct before trusting your primary metric.
- **Return to literature when confused**: Don't guess — search. If results surprise you or assumptions break, go find papers. Use Exa MCP for discovery, Semantic Scholar for specific ML/AI paper lookup, arXiv for preprints.
- **Never stop**: Don't wait for human approval on routine decisions. If a skill or tool suggests collaboration, adapt and keep going. Find the best path forward autonomously. The human will see your progress reports and can redirect if needed.
- **Use whatever compute is available**: Adapt to the user's environment — local GPU, cluster job submission, cloud instances, or just CPU. If no GPU is available, use CPU and adjust experiment scale accordingly. Don't block on compute availability.
## Quality Standards
**Good agent behavior:**
- Hypotheses have mechanistic reasoning ("X because Y, predicting Z"), not just "try X"
- findings.md builds a coherent narrative, not a flat list of results
- Negative results are recorded with what they rule out
- The agent updates its model when experiments contradict expectations
- Progress reports tell a research story with compelling visualizations
**Bad agent behavior:**
- Pure hyperparameter sweeps without interpretation
- findings.md is just experiment logs copy-pasted
- Agent never revisits its assumptions after failures
- Optimizing metrics without understanding why changes work
## When to Use vs Alternatives
**Use autoresearch when:**
- You have a research question explorable through experiments
- There's a measurable proxy metric for inner loop optimization
- The real contribution requires synthesis beyond the metric
- You want continuous autonomous research operation
**Use individual domain skills instead when:**
- You have a specific one-off task (train a model, run eval, write a paper)
- No iterative experimentation needed
## Common Issues
**Inner loop stalls (no metric improvement)**
Run an outer loop. Is the metric the right one? Is the search space exhausted? Consider broadening or pivoting. Search literature for new approaches.
**Stuck and not making progress**
Don't keep trying random changes. Step back: search literature for related work, invoke `21-research-ideation/` brainstorming skills, or run an outer loop reflection. Being stuck means you need new information or a new perspective, not more experiments.
**Results contradict baseline expectations**
Investigate, don't ignore. Return to literature — your protocol might have an error, the published baseline may be wrong, or conditions differ. Update findings.md with what you learn.
**Agent loses context between ticks**
Ensure research-state.yaml and findings.md are updated after every action. These files are your memory across sessions.
**Can't find relevant papers**
Try multiple approaches in order: Exa MCP for broad search, Semantic Scholar for specific ML/AI paper lookup (`pip install semanticscholar`), arXiv for preprints (`pip install arxiv`). Check `20-ml-paper-writing` skill's `references/citation-workflow.md` for complete API code. Note: Google Scholar has no official API — use Semantic Scholar instead for programmatic search.
**No GPU available**
Use CPU and scale experiments down. Many research tasks (analysis, interpretability, small model training) run fine on CPU. Adjust experiment design to fit available compute rather than blocking.
**Experiments take longer than /loop interval**
Normal. On the next tick, check if it finished. If not, keep waiting or do something else useful (update notes, search papers). Adjust interval if needed.
**Not sure when to conclude**
Three questions: Do you have a strongly supported finding? Can you explain WHY it works? Would findings.md make a convincing paper abstract? If yes to all: conclude.
## Advanced Topics
- **Detailed agent continuity**: [references/agent-continuity.md](references/agent-continuity.md)
- **Progress presentation templates**: [references/progress-reporting.md](references/progress-reporting.md)
- **Complete skill routing**: [references/skill-routing.md](references/skill-routing.md)
@@ -0,0 +1,94 @@
# Agent Continuity: Keeping Research Running
Autonomous research requires agents that keep working continuously — hours, overnight, or days. This reference covers how to set up the wall-clock loop on each platform.
## MANDATORY: Set Up Before Anything Else
The wall-clock loop is the first thing you do when starting autoresearch. Without it, the agent does one cycle of work and stops. The loop fires every 20 minutes and simply tells the agent: "keep working, check if anything's wrong."
This is completely separate from the research inner/outer loops. The wall-clock loop is a fixed-interval prompt injection. The research loops run at experiment speed (minutes to hours per cycle). They will not sync and don't need to.
## Claude Code: /loop
**Run this as your very first action:**
```
/loop 20m Continue autoresearch. Read research-state.yaml and findings.md. Check: is the current experiment done? Are there errors? Is progress stalling? If everything is on track, keep working. If something is wrong, step back and fix it. Never idle.
```
### What Happens on Each Tick
The `/loop` message fires every 20 minutes. When it fires:
1. **Read state** — read `research-state.yaml` and `findings.md` to remember where you are and what you've learned
2. **Check health** — is the current experiment running? Did it error? Is training diverging?
3. **If on track** → continue whatever you were doing (don't restart or re-plan)
4. **If stuck or broken** → step back, diagnose the issue, fix it, then continue
5. **If idle** → you shouldn't be. Pick the next hypothesis and start an experiment
The tick is a nudge, not a phase boundary. If the previous tick's work isn't done, just keep going.
### When Experiments Run Longer Than 20 Minutes
This is normal and expected. On the next tick:
- Check if the experiment finished
- If still running: monitor for errors, do lightweight work (update notes, search papers)
- Don't restart or duplicate the experiment
### Context Recovery
Your primary memory between ticks is `research-state.yaml` and `findings.md`. Keep them updated after every experiment and reflection. If you lose context, read both files plus `research-log.md` to reconstruct where you are.
## OpenClaw: Cron Job
**Set up a 20-minute cron job in the current session as your very first action.** Use the `cron.add` tool to create a recurring job bound to the current chat session:
```json
{
"name": "autoresearch-loop",
"schedule": { "kind": "every", "everyMs": 1200000 },
"sessionTarget": "current",
"payload": {
"kind": "agentTurn",
"message": "Continue autoresearch. Read research-state.yaml and findings.md. Check: is the current experiment done? Are there errors? Is progress stalling? If everything is on track, keep working. If something is wrong, step back and fix it. Never idle."
}
}
```
Key details:
- `sessionTarget: "current"` binds the cron to the current chat session (resolved to `session:<sessionKey>` at creation time), so it maintains conversation context across ticks
- `everyMs: 1200000` = 20 minutes
- Verify with `cron.list` that the job is created and enabled
- To check run history later: `cron.runs` with the job ID
### Context Between Cron Ticks
OpenClaw cron invocations may start fresh each time. Your workspace files are your memory:
- `research-state.yaml` — where you are, what's active
- `findings.md` — what you've learned (read this every time!)
- `research-log.md` — what happened chronologically
Keep these updated after every action so the next cron tick can pick up seamlessly.
### Progress Reports
OpenClaw can't `open` HTML files locally like Claude Code can. When you have something to report:
1. Generate a PDF progress summary (use Python with reportlab, matplotlib, or similar)
2. Include: research question, key results, optimization trajectory plot, current understanding, next steps
3. Send it to the user via Telegram, WhatsApp, or Slack — whichever channel they use
4. When you get an exciting result or interesting plot, send it right away — don't wait for a full report
## Research State as Ground Truth
Both platforms share the same ground truth: the workspace files.
| File | Purpose | Update Frequency |
|---|---|---|
| `research-state.yaml` | Machine-readable state | After every experiment and reflection |
| `research-log.md` | Decision timeline | After every significant action |
| `findings.md` | Narrative understanding + project memory | After every outer loop |
| `experiments/*/results/` | Raw experimental data | After every experiment |
The wall-clock loop (`/loop` or cron) is just the trigger. The workspace files are the memory. Keep them current.
@@ -0,0 +1,165 @@
# Progress Reporting: Research Presentations
When the research produces something worth sharing, create a compelling presentation — not a status dump, but a research story with visuals.
## When to Report
You decide when progress is meaningful enough to report. Consider reporting:
- After an outer loop reflection that identified a significant pattern
- When the optimization trajectory shows clear, sustained improvement
- After a pivot — explain why the direction changed
- Before requesting human input on a major decision
- When concluding the research, before paper writing
Maximum frequency: once per /loop tick or heartbeat cycle. Minimum: whenever you have something a human would find interesting.
## What Makes a Good Research Presentation
A good progress report reads like a research talk, not a database query. It should:
1. **Tell a story**: why we started, what we tried, what we found, what it means
2. **Show, don't just tell**: include plots, tables, comparisons — not just text
3. **Be selective**: highlight the interesting findings, don't exhaustively list every experiment
4. **End with direction**: what happens next and why
## Recommended Sections
Adapt these to what's compelling from your current research. Skip sections that aren't relevant. Add sections the research demands.
### 1. Research Question and Motivation
- What are we investigating and why does it matter?
- One paragraph, accessible to someone unfamiliar with the project
### 2. Approach
- What's our method? What are we optimizing?
- The two-loop architecture in one sentence
### 3. Optimization Trajectory (The Karpathy Plot)
- X-axis: experiment number or wall-clock time
- Y-axis: proxy metric value
- Show baseline as a horizontal line
- Annotate significant jumps with what change caused them
- This is often the most compelling visual — include it whenever possible
### 4. Key Findings
- The 2-3 most significant results with supporting evidence
- Include plots, metric tables, comparison charts
- Explain WHY results are significant, not just WHAT they are
### 5. What We Tried (Decision Map)
- A selective view of the hypothesis tree
- Focus on the reasoning: why each direction was chosen, what it taught us
- Include both successes and informative failures
### 6. Current Understanding
- The findings.md narrative, but presented compellingly
- What's our best explanation for the patterns we see?
### 7. Next Steps
- What experiments are planned and why
- What questions remain open
- Any decisions that need human input
## The Optimization Trajectory Plot
This is the signature visual of autoresearch — a chart showing metric improvement over experiments.
Minimal implementation (SVG-based, no dependencies):
```python
def generate_trajectory_svg(trajectory_data, width=800, height=400):
    """Generate an SVG optimization trajectory chart.

    Args:
        trajectory_data: list of {"run": int, "metric": float, "label": str}
            dicts, in run order. Only "metric" is required. The first
            entry's metric is drawn as the baseline. A non-empty "label" is
            attached to its data point as a hover tooltip.
        width: total SVG width in pixels.
        height: total SVG height in pixels.

    Returns:
        The chart as an SVG markup string, or the HTML placeholder
        "<p>No experiments yet.</p>" when trajectory_data is empty.

    Raises:
        KeyError: if an entry has no "metric" key.
    """
    from html import escape  # local import keeps the snippet copy-paste ready

    if not trajectory_data:
        return "<p>No experiments yet.</p>"

    metrics = [d["metric"] for d in trajectory_data]
    min_m, max_m = min(metrics), max(metrics)
    # Pad the y-range by 10% of the spread. When every metric is identical the
    # spread is 0, so fall back to a fixed margin; otherwise y_pos would
    # divide by zero.
    margin = (max_m - min_m) * 0.1 or 0.1
    y_min, y_max = min_m - margin, max_m + margin

    padding = 60
    plot_w = width - 2 * padding
    plot_h = height - 2 * padding
    n = len(trajectory_data)

    def x_pos(i):
        # max(n - 1, 1) keeps a single data point from dividing by zero.
        return padding + (i / max(n - 1, 1)) * plot_w

    def y_pos(v):
        # SVG y grows downward, so larger metric values map to smaller y.
        return padding + plot_h - ((v - y_min) / (y_max - y_min)) * plot_h

    # Collect fragments and join once instead of repeated string concatenation.
    parts = [
        f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">',
        f'<rect width="{width}" height="{height}" fill="#1a1a2e" rx="8"/>',
    ]

    # Grid lines: five evenly spaced horizontals, each with its y-axis value.
    for i in range(5):
        y = padding + i * plot_h / 4
        val = y_max - i * (y_max - y_min) / 4
        parts.append(
            f'<line x1="{padding}" y1="{y}" x2="{width-padding}" y2="{y}" stroke="#333" stroke-dasharray="4"/>'
        )
        parts.append(
            f'<text x="{padding-8}" y="{y+4}" fill="#888" text-anchor="end" font-size="11">{val:.3f}</text>'
        )

    # Baseline line: the first run's metric, drawn across the full plot width.
    baseline = trajectory_data[0]["metric"]
    by = y_pos(baseline)
    parts.append(
        f'<line x1="{padding}" y1="{by}" x2="{width-padding}" y2="{by}" stroke="#ff6b6b" stroke-dasharray="6" opacity="0.7"/>'
    )
    parts.append(
        f'<text x="{width-padding+5}" y="{by+4}" fill="#ff6b6b" font-size="10">baseline</text>'
    )

    # Data line
    points = " ".join(
        f"{x_pos(i)},{y_pos(d['metric'])}" for i, d in enumerate(trajectory_data)
    )
    parts.append(
        f'<polyline points="{points}" fill="none" stroke="#4ecdc4" stroke-width="2"/>'
    )

    # Data points. A labelled point carries its label as a <title> child,
    # which SVG viewers show as a tooltip; the label is escaped so arbitrary
    # text cannot break the markup. Unlabelled points stay self-closing.
    for i, d in enumerate(trajectory_data):
        cx, cy = x_pos(i), y_pos(d["metric"])
        label = d.get("label")
        if label:
            parts.append(
                f'<circle cx="{cx}" cy="{cy}" r="4" fill="#4ecdc4">'
                f'<title>{escape(str(label))}</title></circle>'
            )
        else:
            parts.append(f'<circle cx="{cx}" cy="{cy}" r="4" fill="#4ecdc4"/>')

    # Title and x-axis caption
    parts.append(
        f'<text x="{width/2}" y="24" fill="#eee" text-anchor="middle" font-size="14" font-weight="bold">Optimization Trajectory</text>'
    )
    parts.append(
        f'<text x="{width/2}" y="{height-10}" fill="#888" text-anchor="middle" font-size="11">Experiment Run</text>'
    )
    parts.append('</svg>')
    return "".join(parts)
```
Embed the SVG output directly in the HTML report. Annotate significant jumps with brief labels.
## HTML Presentation Template
Use [templates/progress-presentation.html](../templates/progress-presentation.html) as a starting point. It provides:
- Clean, dark-themed styling suitable for research presentations
- Responsive layout
- Section scaffolding matching the recommended structure
- Placeholder for the trajectory chart
Replace placeholder content with your actual research data. Add, remove, or rearrange sections as the research demands. The template is a scaffold, not a constraint.
### Claude Code
Generate the HTML, then show it to the human:
```bash
open to_human/progress-001.html
```
### OpenClaw
Generate a PDF version. Options:
- Use Python `weasyprint` to convert HTML to PDF
- Use `matplotlib` to generate plots directly as PDF
- Create a simple markdown → PDF pipeline
Note the PDF path in HEARTBEAT.md so the human knows to look at it.
## Presentation Quality Tips
- **One insight per section** — don't overload
- **Label axes and units** on all plots
- **Use color consistently** — one color for improvements, another for baselines
- **Include confidence intervals** or error bars where meaningful
- **Show the trajectory early** — it's the hook that tells the reader "this is working"
- **End with a clear next step** — the human should know what happens next without asking
@@ -0,0 +1,218 @@
# Skill Routing: When to Use Which Domain Skill
The autoresearch skill orchestrates — domain skills execute. This reference maps research activities to the skills library.
## Routing Principle
When you encounter a domain-specific task during research, search the skills library for the right tool. Read the SKILL.md of the relevant skill before starting — it contains workflows, common issues, and production-ready code examples.
## Complete Routing Map
### Data and Preprocessing
| Task | Skill | Location |
|---|---|---|
| Large-scale data processing | Ray Data | `05-data-processing/ray-data/` |
| Data curation and filtering | NeMo Curator | `05-data-processing/nemo-curator/` |
| Custom tokenizer training | HuggingFace Tokenizers | `02-tokenization/huggingface-tokenizers/` |
| Subword tokenization | SentencePiece | `02-tokenization/sentencepiece/` |
### Model Architecture and Training
| Task | Skill | Location |
|---|---|---|
| Large-scale pretraining | Megatron-Core | `01-model-architecture/megatron-core/` |
| Lightweight LLM training | LitGPT | `01-model-architecture/litgpt/` |
| State-space models | Mamba | `01-model-architecture/mamba/` |
| Linear attention models | RWKV | `01-model-architecture/rwkv/` |
| Small-scale pretraining | NanoGPT | `01-model-architecture/nanogpt/` |
### Fine-tuning
| Task | Skill | Location |
|---|---|---|
| Multi-method fine-tuning | Axolotl | `03-fine-tuning/axolotl/` |
| Template-based fine-tuning | LLaMA-Factory | `03-fine-tuning/llama-factory/` |
| Fast LoRA fine-tuning | Unsloth | `03-fine-tuning/unsloth/` |
| PyTorch-native fine-tuning | Torchtune | `03-fine-tuning/torchtune/` |
### Post-training (RL / Alignment)
| Task | Skill | Location |
|---|---|---|
| PPO, DPO, SFT pipelines | TRL | `06-post-training/trl/` |
| Group Relative Policy Optimization | GRPO | `06-post-training/grpo-rl-training/` |
| Scalable RLHF | OpenRLHF | `06-post-training/openrlhf/` |
| Reference-free alignment | SimPO | `06-post-training/simpo/` |
### Interpretability
| Task | Skill | Location |
|---|---|---|
| Transformer circuit analysis | TransformerLens | `04-mechanistic-interpretability/transformerlens/` |
| Sparse autoencoder training | SAELens | `04-mechanistic-interpretability/saelens/` |
| Intervention experiments | NNsight | `04-mechanistic-interpretability/nnsight/` |
| Causal tracing | Pyvene | `04-mechanistic-interpretability/pyvene/` |
### Distributed Training
| Task | Skill | Location |
|---|---|---|
| ZeRO optimization | DeepSpeed | `08-distributed-training/deepspeed/` |
| Fully sharded data parallel | FSDP | `08-distributed-training/fsdp/` |
| Multi-GPU abstraction | Accelerate | `08-distributed-training/accelerate/` |
| Training framework | PyTorch Lightning | `08-distributed-training/pytorch-lightning/` |
| Distributed data + training | Ray Train | `08-distributed-training/ray-train/` |
### Evaluation
| Task | Skill | Location |
|---|---|---|
| Standard LLM benchmarks | lm-evaluation-harness | `11-evaluation/lm-eval-harness/` |
| NeMo-integrated evaluation | NeMo Evaluator | `11-evaluation/nemo-evaluator/` |
| Custom eval tasks | Inspect AI | `11-evaluation/inspect-ai/` |
### Inference and Serving
| Task | Skill | Location |
|---|---|---|
| High-throughput serving | vLLM | `12-inference-serving/vllm/` |
| NVIDIA-optimized inference | TensorRT-LLM | `12-inference-serving/tensorrt-llm/` |
| CPU / edge inference | llama.cpp | `12-inference-serving/llama-cpp/` |
| Structured generation serving | SGLang | `12-inference-serving/sglang/` |
### Experiment Tracking
| Task | Skill | Location |
|---|---|---|
| Full experiment tracking | Weights & Biases | `13-mlops/wandb/` |
| Open-source tracking | MLflow | `13-mlops/mlflow/` |
| Training visualization | TensorBoard | `13-mlops/tensorboard/` |
### Optimization Techniques
| Task | Skill | Location |
|---|---|---|
| Efficient attention | Flash Attention | `10-optimization/flash-attention/` |
| 4/8-bit quantization | bitsandbytes | `10-optimization/bitsandbytes/` |
| GPTQ quantization | GPTQ | `10-optimization/gptq/` |
| AWQ quantization | AWQ | `10-optimization/awq/` |
| GGUF format (llama.cpp) | GGUF | `10-optimization/gguf/` |
| PyTorch-native quantization | Quanto | `10-optimization/quanto/` |
### Safety and Alignment
| Task | Skill | Location |
|---|---|---|
| Constitutional AI training | Constitutional AI | `07-safety-alignment/constitutional-ai/` |
| Content safety classification | LlamaGuard | `07-safety-alignment/llamaguard/` |
| Guardrail pipelines | NeMo Guardrails | `07-safety-alignment/nemo-guardrails/` |
| Prompt injection detection | Prompt Guard | `07-safety-alignment/prompt-guard/` |
### Infrastructure
| Task | Skill | Location |
|---|---|---|
| Serverless GPU compute | Modal | `09-infrastructure/modal/` |
| Multi-cloud orchestration | SkyPilot | `09-infrastructure/skypilot/` |
| GPU cloud instances | Lambda Labs | `09-infrastructure/lambda-labs/` |
### Agents and RAG
| Task | Skill | Location |
|---|---|---|
| Agent pipelines | LangChain | `14-agents/langchain/` |
| Knowledge retrieval agents | LlamaIndex | `14-agents/llamaindex/` |
| Lightweight agents | Smolagents | `14-agents/smolagents/` |
| Claude-based agents | Claude Agent SDK | `14-agents/claude-agent-sdk/` |
| Vector store (local) | Chroma | `15-rag/chroma/` |
| Vector similarity search | FAISS | `15-rag/faiss/` |
| Text embeddings | Sentence Transformers | `15-rag/sentence-transformers/` |
| Managed vector DB | Pinecone | `15-rag/pinecone/` |
| Scalable vector DB | Milvus | `15-rag/milvus/` |
### Prompt Engineering and Structured Output
| Task | Skill | Location |
|---|---|---|
| Prompt optimization | DSPy | `16-prompt-engineering/dspy/` |
| Structured LLM output | Instructor | `16-prompt-engineering/instructor/` |
| Constrained generation | Guidance | `16-prompt-engineering/guidance/` |
| Grammar-based generation | Outlines | `16-prompt-engineering/outlines/` |
### Multimodal
| Task | Skill | Location |
|---|---|---|
| Vision-language models | CLIP | `18-multimodal/clip/` |
| Speech recognition | Whisper | `18-multimodal/whisper/` |
| Visual instruction tuning | LLaVA | `18-multimodal/llava/` |
| Vision-language (Qwen) | Qwen2-VL | `18-multimodal/qwen2-vl/` |
| Vision-language (Mistral) | Pixtral | `18-multimodal/pixtral/` |
| Visual understanding | Florence-2 | `18-multimodal/florence-2/` |
| Document retrieval | ColPali | `18-multimodal/colpali/` |
### Observability
| Task | Skill | Location |
|---|---|---|
| LLM tracing and debugging | LangSmith | `17-observability/langsmith/` |
| LLM observability platform | Phoenix | `17-observability/phoenix/` |
### Emerging Techniques
| Task | Skill | Location |
|---|---|---|
| Mixture of Experts training | MoE Training | `19-emerging-techniques/moe-training/` |
| Combining trained models | Model Merging | `19-emerging-techniques/model-merging/` |
| Extended context windows | Long Context | `19-emerging-techniques/long-context/` |
| Faster inference via drafting | Speculative Decoding | `19-emerging-techniques/speculative-decoding/` |
| Teacher-student compression | Knowledge Distillation | `19-emerging-techniques/knowledge-distillation/` |
| Reducing model size | Model Pruning | `19-emerging-techniques/model-pruning/` |
### Research Output
| Task | Skill | Location |
|---|---|---|
| Generate research ideas | Research Ideation | `21-research-ideation/` |
| Write publication-ready paper | ML Paper Writing | `20-ml-paper-writing/` |
## Common Research Workflows
### "I need to fine-tune a model and evaluate it"
1. Pick fine-tuning skill based on needs (Unsloth for speed, Axolotl for flexibility)
2. Use lm-evaluation-harness for standard benchmarks
3. Track with W&B or MLflow
### "I need to understand what the model learned"
1. Use TransformerLens for circuit-level analysis
2. Train SAEs with SAELens for feature-level understanding
3. Run interventions with NNsight or Pyvene
### "I need to do RL training"
1. Start with TRL for standard PPO/DPO
2. Use GRPO skill for DeepSeek-R1 style training
3. Scale with OpenRLHF if needed
### "I need to run experiments on cloud GPUs"
1. Modal for quick serverless runs
2. SkyPilot for multi-cloud optimization
3. Lambda Labs for dedicated instances
## Finding Skills
If you're not sure which skill to use:
```bash
# List skills whose path matches a keyword (case-insensitive)
ls */*/SKILL.md | grep -i "keyword"
# Search skill descriptions for a keyword
grep -l "keyword" */*/SKILL.md
```
Or search the repository's README.md which lists all skills with descriptions.
@@ -0,0 +1,43 @@
# Research Findings
## Research Question
<!-- What are we trying to discover? One clear sentence. -->
## Current Understanding
<!-- Updated after each outer loop cycle. What do we know so far?
What patterns explain our results? What's the mechanism?
This section should read like the core argument of a paper. -->
## Key Results
<!-- Significant experimental findings. Include metrics, comparisons, and
brief interpretation. Link to experiment directories for full details. -->
## Patterns and Insights
<!-- What emerges across multiple experiments? What types of changes
consistently work or fail? Why? -->
## Lessons and Constraints
<!-- Specific actionable learnings that should guide future experiments.
Things you tried that didn't work and WHY, so you don't repeat them.
Constraints you discovered about the problem space.
Examples:
- Weight decay > 0.1 causes training instability at 125M param scale
- SwiGLU and RoPE improvements stack because they're orthogonal (FFN vs positional)
- Baseline only reproduces published numbers with batch_size=64, not 32
- Sleep phases before memorization completion hurt — model needs memories to consolidate -->
## Open Questions
<!-- What remains unanswered? What would strengthen or challenge
our current understanding? -->
## Optimization Trajectory
<!-- Summary of inner loop progress. How has the metric evolved?
Note inflection points and what caused them. -->
@@ -0,0 +1,306 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Research Progress</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, sans-serif;
background: #0d1117;
color: #e6edf3;
line-height: 1.6;
padding: 2rem;
max-width: 1100px;
margin: 0 auto;
}
header {
text-align: center;
padding: 3rem 0 2rem;
border-bottom: 1px solid #21262d;
margin-bottom: 2.5rem;
}
header h1 {
font-size: 2.2rem;
font-weight: 700;
color: #f0f6fc;
margin-bottom: 0.5rem;
}
.subtitle {
font-size: 1.15rem;
color: #8b949e;
font-style: italic;
max-width: 700px;
margin: 0 auto 1rem;
}
.meta {
font-size: 0.85rem;
color: #484f58;
}
.meta span {
display: inline-block;
margin: 0 0.5rem;
padding: 0.15rem 0.6rem;
background: #161b22;
border: 1px solid #21262d;
border-radius: 12px;
}
section {
margin-bottom: 3rem;
}
section h2 {
font-size: 1.4rem;
font-weight: 600;
color: #f0f6fc;
margin-bottom: 1rem;
padding-bottom: 0.5rem;
border-bottom: 1px solid #21262d;
}
p, li { color: #c9d1d9; }
.card {
background: #161b22;
border: 1px solid #21262d;
border-radius: 8px;
padding: 1.5rem;
margin-bottom: 1rem;
}
.card h3 {
font-size: 1.05rem;
color: #58a6ff;
margin-bottom: 0.5rem;
}
.result-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
gap: 1rem;
margin-bottom: 1.5rem;
}
.stat-card {
background: #161b22;
border: 1px solid #21262d;
border-radius: 8px;
padding: 1.2rem;
text-align: center;
}
.stat-card .value {
font-size: 2rem;
font-weight: 700;
color: #58a6ff;
}
.stat-card .label {
font-size: 0.8rem;
color: #8b949e;
text-transform: uppercase;
letter-spacing: 0.05em;
}
.stat-card.positive .value { color: #3fb950; }
.stat-card.negative .value { color: #f85149; }
table {
width: 100%;
border-collapse: collapse;
margin: 1rem 0;
}
th {
text-align: left;
padding: 0.6rem 1rem;
background: #161b22;
color: #8b949e;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 0.05em;
border-bottom: 1px solid #21262d;
}
td {
padding: 0.6rem 1rem;
border-bottom: 1px solid #21262d;
font-size: 0.95rem;
}
.badge {
display: inline-block;
padding: 0.15rem 0.5rem;
border-radius: 10px;
font-size: 0.75rem;
font-weight: 600;
}
.badge-supported { background: #0d2818; color: #3fb950; border: 1px solid #1b4332; }
.badge-refuted { background: #2d1215; color: #f85149; border: 1px solid #4a1c20; }
.badge-active { background: #0c2d6b; color: #58a6ff; border: 1px solid #1158c7; }
.badge-pending { background: #1c1c1c; color: #8b949e; border: 1px solid #333; }
.chart-container {
background: #161b22;
border: 1px solid #21262d;
border-radius: 8px;
padding: 1.5rem;
text-align: center;
margin: 1rem 0;
}
.next-steps {
background: #0c2d6b22;
border: 1px solid #1158c744;
border-radius: 8px;
padding: 1.5rem;
}
.next-steps h3 { color: #58a6ff; margin-bottom: 0.5rem; }
.next-steps ul { padding-left: 1.5rem; }
.next-steps li { margin-bottom: 0.3rem; }
footer {
text-align: center;
padding: 2rem 0;
color: #484f58;
font-size: 0.8rem;
border-top: 1px solid #21262d;
}
</style>
</head>
<body>
<!--
AGENT INSTRUCTIONS:
This is a starting point. Fill in, rearrange, add, or remove sections
based on what's compelling from your current research. The goal is a
research story, not a status dashboard.
Replace {{PLACEHOLDERS}} with actual content.
Embed SVG charts inline (see progress-reporting.md for the trajectory plot function).
Add additional sections as needed.
-->
<header>
<h1>{{PROJECT_TITLE}}</h1>
<p class="subtitle">{{RESEARCH_QUESTION}}</p>
<p class="meta">
<span>{{DATE}}</span>
<span>{{N_EXPERIMENTS}} experiments</span>
<span>Status: {{STATUS}}</span>
</p>
</header>
<!-- Summary stats -->
<section>
<div class="result-grid">
<div class="stat-card positive">
<div class="value">{{BEST_METRIC}}</div>
<div class="label">Best Metric</div>
</div>
<div class="stat-card">
<div class="value">{{BASELINE_METRIC}}</div>
<div class="label">Baseline</div>
</div>
<div class="stat-card positive">
<div class="value">{{IMPROVEMENT}}</div>
<div class="label">Improvement</div>
</div>
<div class="stat-card">
<div class="value">{{N_HYPOTHESES}}</div>
<div class="label">Hypotheses Tested</div>
</div>
</div>
</section>
<!-- Background and motivation -->
<section id="background">
<h2>Background & Motivation</h2>
<div class="card">
<!-- Why does this research matter? What gap are we addressing? -->
<p>{{BACKGROUND_TEXT}}</p>
</div>
</section>
<!-- Optimization trajectory - THE key visual -->
<section id="trajectory">
<h2>Optimization Trajectory</h2>
<div class="chart-container">
<!-- Embed SVG chart here. See references/progress-reporting.md
for the generate_trajectory_svg() function. -->
{{TRAJECTORY_SVG}}
</div>
</section>
<!-- Key findings -->
<section id="findings">
<h2>Key Findings</h2>
<!-- Add cards for each significant finding -->
<div class="card">
<h3>{{FINDING_1_TITLE}}</h3>
<p>{{FINDING_1_DESCRIPTION}}</p>
<!-- Include inline plots, tables, or metrics as needed -->
</div>
</section>
<!-- What was tried -->
<section id="experiments">
<h2>What We Tried</h2>
<table>
<thead>
<tr>
<th>Hypothesis</th>
<th>Change</th>
<th>Result</th>
<th>Status</th>
</tr>
</thead>
<tbody>
<!-- Add rows for notable experiments -->
<tr>
<td>{{H_ID}}</td>
<td>{{CHANGE_SUMMARY}}</td>
<td>{{METRIC_DELTA}}</td>
<td><span class="badge badge-supported">{{H_STATUS}}</span></td>
</tr>
</tbody>
</table>
</section>
<!-- Current understanding -->
<section id="understanding">
<h2>Current Understanding</h2>
<div class="card">
<!-- The narrative from findings.md, but presented compellingly -->
<p>{{CURRENT_UNDERSTANDING}}</p>
</div>
</section>
<!-- Next steps -->
<section id="next">
<h2>Next Steps</h2>
<div class="next-steps">
<ul>
<li>{{NEXT_STEP_1}}</li>
<li>{{NEXT_STEP_2}}</li>
<li>{{NEXT_STEP_3}}</li>
</ul>
</div>
</section>
<footer>
Generated by Autoresearch | {{DATE}}
</footer>
</body>
</html>
@@ -0,0 +1,40 @@
# Research Log
Chronological record of research decisions and actions. Append-only.
| # | Date | Type | Summary |
|---|------|------|---------|
| | | | |
<!-- Entry types:
bootstrap — initial scoping, literature search, hypothesis formation
inner-loop — experiment run and result
outer-loop — synthesis, reflection, direction decision
pivot — change in research direction
report — progress presentation generated
conclude — decision to finalize and write paper
Example entries:
| 1 | 2026-03-15 | bootstrap | Searched Semantic Scholar + arXiv for efficient transformer architectures. Found 8 relevant papers. Gap: no systematic comparison of GLU variants on small models. Formed 3 hypotheses. Baseline: NanoGPT 5-min run, val_loss=4.82. |
| 2 | 2026-03-15 | inner-loop | H1 run_001: swapped ReLU for SwiGLU in FFN. 5-min training run. val_loss=4.61 (baseline 4.82, delta -0.21). Kept. |
| 3 | 2026-03-15 | inner-loop | H1 run_002: increased FFN hidden dim from 4x to 5.3x to match SwiGLU param count. val_loss=4.58 (-0.03 vs run_001). Marginal — SwiGLU benefit mostly from gating, not extra params. |
| 4 | 2026-03-15 | inner-loop | H1 run_003: tried GEGLU instead of SwiGLU. val_loss=4.63. Slightly worse than SwiGLU. SwiGLU wins for this scale. |
| 5 | 2026-03-15 | inner-loop | H2 run_004: replaced learned positional embeddings with RoPE. val_loss=4.55 (-0.06 vs SwiGLU baseline). Promising — stacks with SwiGLU. |
| 6 | 2026-03-15 | inner-loop | H2 run_005: RoPE + SwiGLU combined. val_loss=4.41 (-0.41 vs original baseline). Best so far. |
| 7 | 2026-03-16 | outer-loop | Reviewed 5 runs. Pattern: gating mechanisms (SwiGLU) and rotary embeddings (RoPE) give independent gains that stack. Combined improvement ~9%. But WHY do they stack? Hypothesis: they operate on orthogonal aspects (FFN expressiveness vs positional encoding). Direction: DEEPEN — test if adding RMSNorm also stacks independently. |
| 8 | 2026-03-16 | inner-loop | H3 run_006: replaced LayerNorm with RMSNorm. val_loss=4.39 (-0.02). Small gain. Stacks but diminishing returns on normalization. |
| 9 | 2026-03-17 | outer-loop | 6 runs complete. Optimization plateau around val_loss=4.39. The easy architectural wins (SwiGLU, RoPE) are captured. Searched literature on training dynamics — found papers on warmup schedules at small scale. Direction: BROADEN — shift from architecture to training recipe. |
| 10 | 2026-03-17 | report | Generated progress-001.html with trajectory plot showing 9% improvement from architectural changes. |
Example entries (discovery-type research — understanding grokking):
| 1 | 2026-03-20 | bootstrap | Searched literature on grokking and delayed generalization. Found Nanda et al. progress measures, Grokfast spectral filtering. Gap: no connection to memory consolidation theory from neuroscience. 3 hypotheses formed. |
| 2 | 2026-03-20 | inner-loop | H1 run_001: trained modular addition transformer to memorization (100% train acc, 0% test). Steps to memorize: 8000. Baseline established. |
| 3 | 2026-03-20 | inner-loop | H1 run_002: continued training with standard weight decay. Grokking at step 48000. Measured progress measure throughout — sharp transition at step 44000. |
| 4 | 2026-03-20 | inner-loop | H1 run_003: inserted "sleep phase" at step 20000 (elevated weight decay + oscillatory LR for 500 steps). Grokking now at step 31000. 35% acceleration. |
| 5 | 2026-03-20 | inner-loop | H1 run_004: sleep phase at step 10000. Grokking at step 27000. Earlier sleep = earlier grokking. |
| 6 | 2026-03-20 | inner-loop | H1 run_005: sleep phase at step 5000 (before full memorization). Grokking at step 38000. Too early hurts — model hadn't memorized enough for consolidation to work. |
| 7 | 2026-03-21 | outer-loop | Reviewed 5 runs. Clear pattern: sleep phases accelerate grokking but only AFTER memorization is complete. This matches memory consolidation theory exactly — you need memories formed before consolidation can reorganize them. Searched for neural slow-wave sleep literature. The weight decay + oscillatory LR during sleep phases mimics synaptic downscaling. Direction: DEEPEN — sweep sleep timing relative to memorization completion. |
| 8 | 2026-03-21 | inner-loop | H1.1 run_006-010: swept sleep insertion at 80%, 100%, 120%, 150%, 200% of memorization step. Sweet spot at 110-120%. Consistent across 3 seeds. |
| 9 | 2026-03-22 | outer-loop | 10 runs complete. The story is clear: neural networks "dream to learn" just like brains — consolidation after encoding, not during. Grokfast achieves similar acceleration through a different mechanism (gradient spectral filtering). Next: compare gradient spectra during our sleep phases vs Grokfast filtering to see if they converge on the same signal. Direction: BROADEN. |
| 10 | 2026-03-22 | report | Generated progress-001.html with sleep timing vs grokking step plot. Key visual: sweet spot curve mirrors neuroscience memory consolidation window. |
-->
@@ -0,0 +1,57 @@
# Research State — Central Project Tracking
# Copy this template to your project root and fill in as you go.
# Updated by the agent after each experiment and reflection.
project:
title: ""
question: "" # The core research question
status: active # active | paused | concluded
started: "" # ISO date
domain: "" # e.g., "mechanistic interpretability", "RL training"
literature:
key_papers: []
# - id: "liu2025superposition"
# title: "Superposition Yields Robust Neural Scaling"
# authors: "Liu et al."
# year: 2025
# relevance: "Proves ETF structure in LM heads"
open_problems: [] # Gaps identified from literature
evidence_gaps: [] # What's missing in the field
hypotheses: []
# List of all hypotheses, active and completed
# - id: H1
# statement: "Testable claim with clear prediction"
# status: pending # pending | active | supported | refuted | inconclusive
# motivation: "Why this is worth testing"
# parent: null # null for root, parent ID (e.g., H1) for sub-hypotheses
# priority: medium # high | medium | low
experiments:
proxy_metric: "" # What we're optimizing and how to compute it
baseline_value: null # Starting point
best_value: null # Best achieved so far
total_runs: 0
trajectory: []
# - run_id: "run_001"
# hypothesis: "H1"
# metric_value: null
# delta: null # Change from baseline
# wall_time_min: null
# change_summary: ""
# timestamp: ""
outer_loop:
cycle: 0 # How many outer loop reflections so far
last_direction: null # deepen | broaden | pivot | conclude
last_reflection: "" # Brief summary of last reflection decision
workspace:
# Track key resource locations
findings: "findings.md"
log: "research-log.md"
literature_dir: "literature/"
experiments_dir: "experiments/"
to_human_dir: "to_human/"
paper_dir: "paper/"
@@ -0,0 +1,5 @@
# Skills Coming Soon
This directory will contain high-quality AI research skills for model architecture.
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
@@ -0,0 +1,469 @@
---
name: implementing-llms-litgpt
description: Implements and trains LLMs using Lightning AI's LitGPT with 20+ pretrained architectures (Llama, Gemma, Phi, Qwen, Mistral). Use when you need clean model implementations, an educational understanding of architectures, or production fine-tuning with LoRA/QLoRA. Single-file implementations, no abstraction layers.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Model Architecture, LitGPT, Lightning AI, LLM Implementation, LoRA, QLoRA, Fine-Tuning, Llama, Gemma, Phi, Mistral, Educational]
dependencies: [litgpt, torch, transformers]
---
# LitGPT - Clean LLM Implementations
## Quick start
LitGPT provides 20+ pretrained LLM implementations with clean, readable code and production-ready training workflows.
**Installation**:
```bash
pip install 'litgpt[extra]'
```
**Load and use any model**:
```python
from litgpt import LLM
# Load pretrained model
llm = LLM.load("microsoft/phi-2")
# Generate text
result = llm.generate(
"What is the capital of France?",
max_new_tokens=50,
temperature=0.7
)
print(result)
```
**List available models**:
```bash
litgpt download list
```
## Common workflows
### Workflow 1: Fine-tune on custom dataset
Copy this checklist:
```
Fine-Tuning Setup:
- [ ] Step 1: Download pretrained model
- [ ] Step 2: Prepare dataset
- [ ] Step 3: Configure training
- [ ] Step 4: Run fine-tuning
```
**Step 1: Download pretrained model**
```bash
# Download Llama 3 8B
litgpt download meta-llama/Meta-Llama-3-8B
# Download Phi-2 (smaller, faster)
litgpt download microsoft/phi-2
# Download Gemma 2B
litgpt download google/gemma-2b
```
Models are saved to `checkpoints/` directory.
**Step 2: Prepare dataset**
LitGPT supports multiple formats:
**Alpaca format** (instruction-response):
```json
[
{
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris."
},
{
"instruction": "Translate to Spanish: Hello, how are you?",
"input": "",
"output": "Hola, ¿cómo estás?"
}
]
```
Save as `data/my_dataset.json`.
**Step 3: Configure training**
```bash
# Full fine-tuning (requires 40GB+ GPU for 7B models)
litgpt finetune \
meta-llama/Meta-Llama-3-8B \
--data JSON \
--data.json_path data/my_dataset.json \
--train.max_steps 1000 \
--train.learning_rate 2e-5 \
--train.micro_batch_size 1 \
--train.global_batch_size 16
# LoRA fine-tuning (efficient, 16GB GPU)
litgpt finetune_lora \
microsoft/phi-2 \
--data JSON \
--data.json_path data/my_dataset.json \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.05 \
--train.max_steps 1000 \
--train.learning_rate 1e-4
```
**Step 4: Run fine-tuning**
Training saves checkpoints to `out/finetune/` automatically.
Monitor training:
```bash
# View logs
tail -f out/finetune/logs.txt
# TensorBoard (if using --train.logger_name tensorboard)
tensorboard --logdir out/finetune/lightning_logs
```
### Workflow 2: LoRA fine-tuning on single GPU
Most memory-efficient option.
```
LoRA Training:
- [ ] Step 1: Choose base model
- [ ] Step 2: Configure LoRA parameters
- [ ] Step 3: Train with LoRA
- [ ] Step 4: Merge LoRA weights (optional)
```
**Step 1: Choose base model**
For limited GPU memory (12-16GB):
- **Phi-2** (2.7B) - Best quality/size tradeoff
- **Llama 3.2 1B** - Smallest, fastest
- **Gemma 2B** - Good reasoning
**Step 2: Configure LoRA parameters**
```bash
# Flag reference (comments cannot follow a trailing "\" in shell):
#   --lora_r           LoRA rank (8-64, higher=more capacity)
#   --lora_alpha       LoRA scaling (typically 2×r)
#   --lora_dropout     Prevent overfitting
#   --lora_query       Apply LoRA to query projection
#   --lora_key         Usually not needed
#   --lora_value       Apply LoRA to value projection
#   --lora_projection  Apply LoRA to output projection
#   --lora_mlp         Usually not needed
#   --lora_head        Usually not needed
litgpt finetune_lora \
microsoft/phi-2 \
--data JSON \
--data.json_path data/my_dataset.json \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.05 \
--lora_query true \
--lora_key false \
--lora_value true \
--lora_projection true \
--lora_mlp false \
--lora_head false
```
LoRA rank guide:
- `r=8`: Lightweight, 2-4MB adapters
- `r=16`: Standard, good quality
- `r=32`: High capacity, use for complex tasks
- `r=64`: Maximum quality, 4× larger adapters
**Step 3: Train with LoRA**
```bash
litgpt finetune_lora \
microsoft/phi-2 \
--data JSON \
--data.json_path data/my_dataset.json \
--lora_r 16 \
--train.epochs 3 \
--train.learning_rate 1e-4 \
--train.micro_batch_size 4 \
--train.global_batch_size 32 \
--out_dir out/phi2-lora
# Memory usage: ~8-12GB for Phi-2 with LoRA
```
**Step 4: Merge LoRA weights** (optional)
Merge LoRA adapters into base model for deployment:
```bash
litgpt merge_lora \
out/phi2-lora/final \
--out_dir out/phi2-merged
```
Now use merged model:
```python
from litgpt import LLM
llm = LLM.load("out/phi2-merged")
```
### Workflow 3: Pretrain from scratch
Train new model on your domain data.
```
Pretraining:
- [ ] Step 1: Prepare pretraining dataset
- [ ] Step 2: Configure model architecture
- [ ] Step 3: Set up multi-GPU training
- [ ] Step 4: Launch pretraining
```
**Step 1: Prepare pretraining dataset**
LitGPT expects tokenized data. Use `prepare_dataset.py`:
```bash
python scripts/prepare_dataset.py \
--source_path data/my_corpus.txt \
--checkpoint_dir checkpoints/tokenizer \
--destination_path data/pretrain \
--split train,val
```
**Step 2: Configure model architecture**
Edit config file or use existing:
```yaml
# config/pythia-160m.yaml
model_name: pythia-160m
block_size: 2048
vocab_size: 50304
n_layer: 12
n_head: 12
n_embd: 768
rotary_percentage: 0.25
parallel_residual: true
bias: true
```
**Step 3: Set up multi-GPU training**
```bash
# Single GPU
litgpt pretrain \
--config config/pythia-160m.yaml \
--data.data_dir data/pretrain \
--train.max_tokens 10_000_000_000
# Multi-GPU with FSDP
litgpt pretrain \
--config config/pythia-1b.yaml \
--data.data_dir data/pretrain \
--devices 8 \
--train.max_tokens 100_000_000_000
```
**Step 4: Launch pretraining**
For large-scale pretraining on cluster:
```bash
# Using SLURM
sbatch --nodes=8 --gpus-per-node=8 \
pretrain_script.sh
# pretrain_script.sh content:
litgpt pretrain \
--config config/pythia-1b.yaml \
--data.data_dir /shared/data/pretrain \
--devices 8 \
--num_nodes 8 \
--train.global_batch_size 512 \
--train.max_tokens 300_000_000_000
```
### Workflow 4: Convert and deploy model
Export LitGPT models for production.
```
Model Deployment:
- [ ] Step 1: Test inference locally
- [ ] Step 2: Quantize model (optional)
- [ ] Step 3: Convert to GGUF (for llama.cpp)
- [ ] Step 4: Deploy with API
```
**Step 1: Test inference locally**
```python
from litgpt import LLM
llm = LLM.load("out/phi2-lora/final")
# Single generation
print(llm.generate("What is machine learning?"))
# Streaming
for token in llm.generate("Explain quantum computing", stream=True):
print(token, end="", flush=True)
# Batch inference
prompts = ["Hello", "Goodbye", "Thank you"]
results = [llm.generate(p) for p in prompts]
```
**Step 2: Quantize model** (optional)
Reduce model size with minimal quality loss:
```bash
# 8-bit quantization (50% size reduction)
litgpt convert_lit_checkpoint \
out/phi2-lora/final \
--dtype bfloat16 \
--quantize bnb.int8
# 4-bit quantization (75% size reduction)
litgpt convert_lit_checkpoint \
out/phi2-lora/final \
--quantize bnb.nf4-dq # Double quantization
```
**Step 3: Convert to GGUF** (for llama.cpp)
```bash
python scripts/convert_lit_checkpoint.py \
--checkpoint_path out/phi2-lora/final \
--output_path models/phi2.gguf \
--model_name microsoft/phi-2
```
**Step 4: Deploy with API**
```python
from fastapi import FastAPI
from litgpt import LLM
app = FastAPI()
llm = LLM.load("out/phi2-lora/final")
@app.post("/generate")
def generate(prompt: str, max_tokens: int = 100):
result = llm.generate(
prompt,
max_new_tokens=max_tokens,
temperature=0.7
)
return {"response": result}
# Run: uvicorn api:app --host 0.0.0.0 --port 8000
```
## When to use vs alternatives
**Use LitGPT when:**
- Want to understand LLM architectures (clean, readable code)
- Need production-ready training recipes
- Educational purposes or research
- Prototyping new model ideas
- Lightning ecosystem user
**Use alternatives instead:**
- **Axolotl/TRL**: More fine-tuning features, YAML configs
- **Megatron-Core**: Maximum performance for >70B models
- **HuggingFace Transformers**: Broadest model support
- **vLLM**: Inference-only (no training)
## Common issues
**Issue: Out of memory during fine-tuning**
Use LoRA instead of full fine-tuning:
```bash
# Instead of litgpt finetune (requires 40GB+)
litgpt finetune_lora # Only needs 12-16GB
```
Or keep the micro-batch small and use gradient accumulation to reach the desired effective batch size:
```bash
litgpt finetune_lora \
... \
--train.gradient_accumulation_iters 4 # Accumulate gradients
```
**Issue: Training too slow**
Enable Flash Attention (built-in, automatic on compatible hardware):
```python
# Already enabled by default on Ampere+ GPUs (A100, RTX 30/40 series)
# No configuration needed
```
Use smaller micro-batch and accumulate:
```bash
--train.micro_batch_size 1 \
--train.global_batch_size 32 \
--train.gradient_accumulation_iters 32 # Effective batch=32
```
**Issue: Model not loading**
Check model name:
```bash
# List all available models
litgpt download list
# Download if not exists
litgpt download meta-llama/Meta-Llama-3-8B
```
Verify checkpoints directory:
```bash
ls checkpoints/
# Should see: meta-llama/Meta-Llama-3-8B/
```
**Issue: LoRA adapters too large**
Reduce LoRA rank:
```bash
--lora_r 8 # Instead of 16 or 32
```
Apply LoRA to fewer layers:
```bash
# Keep LoRA on query/value only; disable the output projection and MLP
--lora_query true \
--lora_value true \
--lora_projection false \
--lora_mlp false
```
## Advanced topics
**Supported architectures**: See [references/supported-models.md](references/supported-models.md) for complete list of 20+ model families with sizes and capabilities.
**Training recipes**: See [references/training-recipes.md](references/training-recipes.md) for proven hyperparameter configurations for pretraining and fine-tuning.
**FSDP configuration**: See [references/distributed-training.md](references/distributed-training.md) for multi-GPU training with Fully Sharded Data Parallel.
**Custom architectures**: See [references/custom-models.md](references/custom-models.md) for implementing new model architectures in LitGPT style.
## Hardware requirements
- **GPU**: NVIDIA (CUDA 11.8+), AMD (ROCm), Apple Silicon (MPS)
- **Memory**:
- Inference (Phi-2): 6GB
- LoRA fine-tuning (7B): 16GB
- Full fine-tuning (7B): 40GB+
- Pretraining (1B): 24GB
- **Storage**: 5-50GB per model (depending on size)
## Resources
- GitHub: https://github.com/Lightning-AI/litgpt
- Docs: https://lightning.ai/docs/litgpt
- Tutorials: https://lightning.ai/docs/litgpt/tutorials
- Model zoo: 20+ pretrained architectures (Llama, Gemma, Phi, Qwen, Mistral, Mixtral, Falcon, etc.)
@@ -0,0 +1,568 @@
# Custom Models
Guide to implementing custom model architectures in LitGPT.
## Overview
LitGPT's clean, single-file implementations make it easy to create custom architectures. You can extend the base `GPT` class or create entirely new models.
**Use cases**:
- Implementing new research architectures
- Adapting models for specific domains
- Experimenting with attention mechanisms
- Adding custom layers or components
## Key Files and Classes
### Core Architecture (`litgpt/model.py`)
**Main classes**:
- `GPT`: Top-level model class
- `Block`: Transformer block (attention + MLP)
- `CausalSelfAttention`: Attention mechanism
- `MLP`: Feed-forward network
- `RMSNorm` / `LayerNorm`: Normalization layers
**Configuration** (`litgpt/config.py`):
- `Config`: Base configuration dataclass
- Model-specific configs: `LlamaConfig`, `MistralConfig`, `PhiConfig`, etc.
## Custom Architecture Workflow
### Step 1: Define Configuration
Create a `Config` dataclass with your model's hyperparameters:
```python
from dataclasses import dataclass
from litgpt.config import Config
@dataclass
class MyModelConfig(Config):
"""Configuration for my custom model."""
# Standard parameters
name: str = "my-model-7b"
block_size: int = 4096
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
n_embd: int = 4096
# Custom parameters
custom_param: float = 0.1
use_custom_attention: bool = True
# Optional: override defaults
rope_base: int = 10000
intermediate_size: int = 11008
```
### Step 2: Implement Custom Components
#### Option A: Custom Attention
```python
from litgpt.model import CausalSelfAttention
import torch
import torch.nn as nn
class CustomAttention(CausalSelfAttention):
"""Custom attention mechanism."""
def __init__(self, config):
super().__init__(config)
# Add custom components
self.custom_proj = nn.Linear(config.n_embd, config.n_embd)
self.custom_param = config.custom_param
def forward(self, x, mask=None, input_pos=None):
B, T, C = x.size()
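# Fused QKV projection: compute once, then split into Q, K, V
qkv = self.attn(x)
q_size = self.n_head * self.head_size
kv_size = self.n_query_groups * self.head_size
q, k, v = qkv.split((q_size, kv_size, kv_size), dim=-1)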
# Custom modification
q = q + self.custom_proj(x) * self.custom_param
# Rest of attention computation
q = q.view(B, T, self.n_head, self.head_size)
k = k.view(B, T, self.n_query_groups, self.head_size)
v = v.view(B, T, self.n_query_groups, self.head_size)
# Scaled dot-product attention
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
y = y.reshape(B, T, C)
return self.proj(y)
```
#### Option B: Custom MLP
```python
from litgpt.model import MLP
class CustomMLP(MLP):
"""Custom feed-forward network."""
def __init__(self, config):
super().__init__(config)
# Add custom layers
self.custom_layer = nn.Linear(config.intermediate_size, config.intermediate_size)
def forward(self, x):
x = self.fc_1(x)
x = self.act(x)
x = self.custom_layer(x) # Custom modification
x = self.fc_2(x)
return x
```
#### Option C: Custom Block
```python
from litgpt.model import Block
class CustomBlock(Block):
"""Custom transformer block."""
def __init__(self, config):
super().__init__(config)
# Replace attention or MLP
self.attn = CustomAttention(config)
# Or: self.mlp = CustomMLP(config)
# Add custom components
self.custom_norm = nn.LayerNorm(config.n_embd)
def forward(self, x, input_pos=None, mask=None):
# Custom forward pass
h = self.norm_1(x)
h = self.attn(h, mask=mask, input_pos=input_pos)
x = x + h
# Custom normalization
x = x + self.custom_norm(x)
x = x + self.mlp(self.norm_2(x))
return x
```
### Step 3: Create Custom GPT Model
```python
from litgpt.model import GPT
import torch
import torch.nn as nn
class CustomGPT(GPT):
"""Custom GPT model."""
def __init__(self, config: MyModelConfig):
# Don't call super().__init__() - we reimplement
nn.Module.__init__(self)
self.config = config
# Standard components
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList(CustomBlock(config) for _ in range(config.n_layer)),
ln_f=nn.LayerNorm(config.n_embd),
)
)
# Custom components
if config.use_custom_attention:
self.custom_embedding = nn.Linear(config.n_embd, config.n_embd)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights (required)."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, input_pos=None):
"""Forward pass (must match base signature)."""
B, T = idx.size()
# Token embeddings
x = self.transformer.wte(idx)
# Custom embedding modification
if self.config.use_custom_attention:
x = x + self.custom_embedding(x)
# Transformer blocks
for block in self.transformer.h:
x = block(x, input_pos=input_pos)
# Final norm + LM head
x = self.transformer.ln_f(x)
return self.lm_head(x)
```
### Step 4: Register Configuration
Add your config to `litgpt/config.py`:
```python
# In litgpt/config.py
configs = [
# ... existing configs ...
# My custom model
dict(
name="my-model-7b",
hf_config=dict(org="myorg", name="my-model-7b"),
block_size=4096,
vocab_size=32000,
n_layer=32,
n_head=32,
n_embd=4096,
custom_param=0.1,
),
]
```
### Step 5: Use Your Custom Model
```python
from litgpt.api import LLM
from my_model import CustomGPT, MyModelConfig
# Initialize
config = MyModelConfig()
model = CustomGPT(config)
# Wrap with LLM API
llm = LLM(model=model, tokenizer_dir="path/to/tokenizer")
# Generate
result = llm.generate("Once upon a time", max_new_tokens=100)
print(result)
```
## Real Example: Adapter Fine-tuning
LitGPT's `Adapter` implementation shows a complete custom architecture:
### Adapter Configuration
```python
@dataclass
class Config(BaseConfig):
"""Adds adapter-specific parameters."""
adapter_prompt_length: int = 10
adapter_start_layer: int = 2
```
### Adapter GPT Model
```python
class GPT(BaseModel):
"""GPT model with adapter layers."""
def __init__(self, config: Config):
nn.Module.__init__(self)
self.config = config
# Standard components
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
# Adapter-specific: gating factor
self.gating_factor = torch.nn.Parameter(torch.zeros(1))
```
### Adapter Block
```python
class Block(BaseBlock):
"""Transformer block with adapter."""
def __init__(self, config: Config, block_idx: int):
super().__init__()
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config, block_idx)
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
# Adapter: add prefix for certain layers
self.adapter_wte = (
nn.Embedding(config.adapter_prompt_length, config.n_embd)
if block_idx >= config.adapter_start_layer
else None
)
```
### Adapter Attention
```python
class CausalSelfAttention(BaseCausalSelfAttention):
"""Attention with adapter prompts."""
def forward(self, x: torch.Tensor, ...) -> torch.Tensor:
B, T, C = x.size()
# Add adapter prefix if enabled
if self.adapter_wte is not None:
adapter_prompts = self.adapter_wte(
torch.arange(self.adapter_prompt_length, device=x.device)
)
adapter_prompts = adapter_prompts.unsqueeze(0).expand(B, -1, -1)
x = torch.cat([adapter_prompts, x], dim=1)
# Standard attention with gating
q, k, v = self.attn(x).split(self.n_embd, dim=2)
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
# Apply gating factor
y = y * self.gating_factor
return self.proj(y)
```
See full implementation: `litgpt/finetune/adapter.py`
## Real Example: AdapterV2
AdapterV2 shows custom linear layers:
### AdapterV2Linear
```python
class AdapterV2Linear(torch.nn.Module):
"""Linear layer with low-rank adapter."""
def __init__(self, in_features, out_features, adapter_rank=8, **kwargs):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
# Adapter: low-rank bottleneck
self.adapter_down = torch.nn.Linear(in_features, adapter_rank, bias=False)
self.adapter_up = torch.nn.Linear(adapter_rank, out_features, bias=False)
# Zero-init the up-projection so the adapter initially contributes nothing (layer starts as the plain linear)
torch.nn.init.zeros_(self.adapter_up.weight)
def forward(self, x):
# Original linear transformation
out = self.linear(x)
# Add adapter contribution
adapter_out = self.adapter_up(self.adapter_down(x))
return out + adapter_out
```
See full implementation: `litgpt/finetune/adapter_v2.py`
## Custom Model Checklist
- [ ] Define `Config` dataclass with all hyperparameters
- [ ] Implement custom components (Attention, MLP, Block)
- [ ] Create custom `GPT` class
- [ ] Implement `_init_weights()` for proper initialization
- [ ] Implement `forward()` matching base signature
- [ ] Register configuration in `litgpt/config.py`
- [ ] Test with small model (100M params) first
- [ ] Verify training convergence
- [ ] Profile memory usage
## Testing Your Custom Model
### Unit Test
```python
import torch
from my_model import CustomGPT, MyModelConfig
def test_custom_model():
"""Test custom model forward pass."""
config = MyModelConfig(
n_layer=2,
n_head=4,
n_embd=128,
vocab_size=1000,
block_size=256,
)
model = CustomGPT(config)
model.eval()
# Test forward pass
batch_size = 2
seq_length = 16
idx = torch.randint(0, config.vocab_size, (batch_size, seq_length))
with torch.no_grad():
logits = model(idx)
assert logits.shape == (batch_size, seq_length, config.vocab_size)
print("✓ Forward pass works")
if __name__ == "__main__":
test_custom_model()
```
### Training Test
```python
from litgpt.api import LLM
def test_training():
"""Test custom model training."""
config = MyModelConfig(n_layer=2, n_head=4, n_embd=128)
model = CustomGPT(config)
# Small dataset for testing
data = [
{"instruction": "Test", "input": "", "output": "OK"}
]
# Should run without errors
llm = LLM(model=model)
# ... training code ...
print("✓ Training works")
```
## Common Patterns
### Adding New Attention Mechanism
```python
class MyAttention(nn.Module):
"""Template for custom attention."""
def __init__(self, config):
super().__init__()
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_size = self.n_embd // self.n_head
# Q, K, V projections
self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# Output projection
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
def forward(self, x, mask=None):
B, T, C = x.size()
# Project Q, K, V
q = self.q_proj(x).view(B, T, self.n_head, self.head_size)
k = self.k_proj(x).view(B, T, self.n_head, self.head_size)
v = self.v_proj(x).view(B, T, self.n_head, self.head_size)
# Custom attention computation here: replace the placeholder below; it must define `attn` (used by the output projection)
# attn = custom_attention_function(q, k, v, mask)
# Output projection
out = self.out_proj(attn.reshape(B, T, C))
return out
```
### Adding Mixture of Experts
```python
class MoELayer(nn.Module):
"""Mixture of Experts layer."""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.moe_top_k
# Router
self.router = nn.Linear(config.n_embd, self.num_experts)
# Experts
self.experts = nn.ModuleList([
MLP(config) for _ in range(self.num_experts)
])
def forward(self, x):
B, T, C = x.size()
# Route tokens to experts
router_logits = self.router(x) # (B, T, num_experts)
router_probs = torch.softmax(router_logits, dim=-1)
# Select top-k experts
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
# Process through selected experts
output = torch.zeros_like(x)
for i in range(self.top_k):
expert_idx = top_k_indices[:, :, i]
expert_prob = top_k_probs[:, :, i:i+1]
# Route to expert
for expert_id in range(self.num_experts):
mask = (expert_idx == expert_id)
if mask.any():
expert_out = self.experts[expert_id](x[mask])
output[mask] += expert_out * expert_prob[mask]
return output
```
### Adding Positional Encoding
```python
class CustomPositionalEncoding(nn.Module):
"""Custom positional encoding."""
def __init__(self, config):
super().__init__()
self.n_embd = config.n_embd
self.register_buffer(
"pos_encoding",
self._create_encoding(config.block_size, config.n_embd)
)
def _create_encoding(self, max_len, d_model):
"""Create positional encoding matrix."""
pos = torch.arange(max_len).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
encoding = torch.zeros(max_len, d_model)
encoding[:, 0::2] = torch.sin(pos * div)
encoding[:, 1::2] = torch.cos(pos * div)
return encoding
def forward(self, x):
"""Add positional encoding."""
return x + self.pos_encoding[:x.size(1), :]
```
## Debugging Tips
1. **Start small**: Test with 2 layers, 128 hidden size
2. **Check shapes**: Print tensor shapes at each step
3. **Verify gradients**: Ensure all parameters have gradients
4. **Compare to base**: Run same config with base `GPT` model
5. **Profile memory**: Use `torch.cuda.memory_summary()`
## References
- Base model: `litgpt/model.py`
- Configuration: `litgpt/config.py`
- Adapter example: `litgpt/finetune/adapter.py`
- AdapterV2 example: `litgpt/finetune/adapter_v2.py`
- LoRA example: `litgpt/finetune/lora.py`
@@ -0,0 +1,451 @@
# Distributed Training
Guide to FSDP (Fully Sharded Data Parallel) distributed training in LitGPT for scaling to multiple GPUs and nodes.
## Overview
LitGPT uses **Lightning Fabric** with **FSDP** to distribute training across multiple GPUs. FSDP shards model parameters, gradients, and optimizer states to enable training models larger than single-GPU memory.
**When to use FSDP**:
- Model doesn't fit on single GPU
- Want faster training with multi-GPU
- Training models >7B parameters
- Need to scale across multiple nodes
## Quick Start
### Single Node Multi-GPU
```bash
# Train Llama 2 7B on 4 GPUs
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--devices 4 \
--data JSON \
--data.json_path data/alpaca.json
```
FSDP is **automatically enabled** when `devices > 1`.
### Multi-Node Training
```bash
# Train on 2 nodes with 8 GPUs each (16 total)
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
--devices 8 \
--num_nodes 2 \
--data JSON \
--data.json_path data/alpaca.json
```
## FSDP Configuration
### Default FSDP Strategy
When multiple devices are used, LitGPT applies this FSDP configuration:
```python
from lightning.fabric.strategies import FSDPStrategy
from litgpt.model import Block
strategy = FSDPStrategy(
auto_wrap_policy={Block},
state_dict_type="full",
sharding_strategy="HYBRID_SHARD"
)
```
**Parameters**:
- `auto_wrap_policy={Block}`: Automatically wraps each transformer `Block` with FSDP
- `state_dict_type="full"`: Saves full model (assembled on rank 0) for easy deployment
- `sharding_strategy="HYBRID_SHARD"`: Shards parameters, gradients, and optimizer states across the GPUs within each node, and replicates the model across nodes
### Sharding Strategies
| Strategy | Shards | Communication | Use Case |
|----------|--------|---------------|----------|
| `FULL_SHARD` (ZeRO-3) | Params + Grads + Optim | All-gather before forward/backward | Maximum memory savings |
| `SHARD_GRAD_OP` (ZeRO-2) | Grads + Optim only | Reduce-scatter after backward | Faster than FULL_SHARD |
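| `HYBRID_SHARD` (default) | Params + Grads + Optim within each node (replicated across nodes) | Sharded collectives inside a node, gradient all-reduce across nodes | Best for multi-node clusters |
| `NO_SHARD` | None (full replica per GPU) | All-reduce of gradients | DDP-like; only when the model fits on a single GPU |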
**Recommendation**: Use default `HYBRID_SHARD` for multi-node, or `FULL_SHARD` for single-node multi-GPU.
### State Dict Types
| Type | Behavior | Use Case |
|------|----------|----------|
| `full` (default) | Gathers all shards on rank 0, saves single file | Easy deployment, inference |
| `sharded` | Each rank saves its shard separately | Faster checkpointing, resume training |
### Auto-Wrap Policy
FSDP wraps model components based on `auto_wrap_policy`:
```python
auto_wrap_policy={Block} # Wrap each transformer block
```
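This means each `Block` (transformer layer) is wrapped as its own FSDP unit and sharded across GPUs. For a 32-layer model fully sharded over 4 GPUs, each GPU holds roughly 1/4 of the parameters of every one of the 32 layers (not 8 whole layers); the full weights of a block are gathered only while that block is being computed.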
## Thunder FSDP (Advanced)
LitGPT includes an experimental **Thunder** extension with enhanced FSDP:
```bash
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--num_nodes 1 \
--compiler thunder \
--strategy fsdp
```
### Thunder FSDP Configuration
```python
from extensions.thunder.pretrain import ThunderFSDPStrategy
strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3",
bucketing_strategy="BLOCK",
state_dict_type="full",
jit=False,
)
```
**Additional Parameters**:
- `sharding_strategy`: `"ZERO3"` (full shard), `"ZERO2"` (grad/optim only)
- `bucketing_strategy`: `"BLOCK"` (combine ops per block), `"LAYER"` (per layer), `"NONE"` (no bucketing)
- `jit`: Whether to apply `thunder.jit(model)` for optimization
- `executors`: Tuple of Thunder executors to enable
**Bucketing Strategy**:
- `"BLOCK"` (default): Combines collective operations for layer blocks → fewer communication calls
- `"LAYER"`: Combines per layer class
- `"NONE"`: No bucketing → more fine-grained but more overhead
## Pretraining with FSDP
### Single Node
```bash
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--train.global_batch_size 512 \
--train.micro_batch_size 8 \
--data Alpaca2k
```
**Memory calculation**:
- TinyLlama 1.1B: ~4GB model + ~4GB gradients + ~8GB optimizer = 16GB per GPU without FSDP
- With FSDP on 8 GPUs: 16GB / 8 = 2GB per GPU ✅ Fits easily
### Multi-Node
```bash
# Launch on 4 nodes with 8 GPUs each (32 total)
litgpt pretrain llama-2-7b \
--devices 8 \
--num_nodes 4 \
--train.global_batch_size 1024 \
--train.micro_batch_size 2 \
--data RedPajama
```
**Memory calculation**:
- Llama 2 7B: ~28GB model + ~28GB gradients + ~56GB optimizer = 112GB total
- With FSDP on 32 GPUs: 112GB / 32 = 3.5GB per GPU ✅
## Fine-tuning with FSDP
### LoRA Fine-tuning (Recommended)
LoRA fine-tuning with FSDP for >7B models:
```bash
# Llama 2 70B LoRA on 8 GPUs
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
--devices 8 \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 16 \
--train.micro_batch_size 1 \
--lora_r 8
```
**Why LoRA with FSDP**:
- Base model sharded with FSDP (memory efficient)
- Only LoRA adapters trained (fast)
- Best of both worlds for large models
### Full Fine-tuning
Full fine-tuning with FSDP:
```bash
# Llama 2 7B full fine-tune on 4 GPUs
litgpt finetune_full meta-llama/Llama-2-7b-hf \
--devices 4 \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 16 \
--train.micro_batch_size 1 \
--train.learning_rate 3e-5
```
## Mixed Precision
FSDP works with mixed precision for memory savings and speedup:
```bash
# BF16 mixed precision (recommended for A100/H100)
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--precision bf16-mixed
# FP16 mixed precision (V100 compatible)
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--precision 16-mixed
```
**Precision options**:
- `bf16-mixed`: BF16 for computation, FP32 for master weights (best for Ampere+)
- `16-mixed`: FP16 for computation, FP32 for master weights (V100)
- `32-true`: Full FP32 (debugging only, slow)
## Gradient Accumulation
Simulate larger batch sizes with gradient accumulation:
```bash
# Simulate global_batch_size=512 with micro_batch_size=2
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--train.global_batch_size 512 \
--train.micro_batch_size 2
# Accumulates over 512/(8*2) = 32 steps per optimizer update
```
**Formula**:
```
Gradient accumulation steps = global_batch_size / (devices × micro_batch_size)
```
## Memory Optimization
### Out of Memory? Try These
1. **Increase devices**:
```bash
--devices 8 # Instead of 4
```
2. **Reduce micro batch size**:
```bash
--train.micro_batch_size 1 # Instead of 2
```
3. **Lower precision**:
```bash
--precision bf16-mixed # Instead of 32-true
```
4. **Use FULL_SHARD**:
```python
strategy = FSDPStrategy(
sharding_strategy="FULL_SHARD" # Maximum memory savings
)
```
5. **Enable activation checkpointing** (implemented in model):
```python
# Recomputes activations during backward pass
# Trades compute for memory
```
6. **Use QLoRA**:
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--quantize bnb.nf4 \
--devices 1 # May not need FSDP with quantization
```
## Checkpointing
### Save Checkpoints
FSDP automatically handles checkpoint saving:
```bash
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--out_dir checkpoints/tinyllama-pretrain
# Saves to: checkpoints/tinyllama-pretrain/final/lit_model.pth
```
With `state_dict_type="full"` (default), rank 0 assembles full model and saves single file.
### Resume Training
```bash
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--resume checkpoints/tinyllama-pretrain/
# Automatically loads latest checkpoint
```
### Convert to HuggingFace
```bash
python scripts/convert_lit_checkpoint.py \
--checkpoint_path checkpoints/tinyllama-pretrain/final/lit_model.pth \
--output_dir models/tinyllama-hf
```
## Performance Tuning
### Communication Backends
LitGPT uses NCCL for GPU communication:
```bash
# Default (NCCL auto-configured)
litgpt pretrain tiny-llama-1.1b --devices 8
# Explicit NCCL settings (advanced)
NCCL_DEBUG=INFO \
NCCL_IB_DISABLE=0 \
litgpt pretrain tiny-llama-1.1b --devices 8
```
**NCCL Environment Variables**:
- `NCCL_DEBUG=INFO`: Enable debug logging
- `NCCL_IB_DISABLE=0`: Use InfiniBand (if available)
- `NCCL_SOCKET_IFNAME=eth0`: Specify network interface
### Multi-Node Setup
**Option 1: SLURM**
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
srun litgpt pretrain llama-2-7b \
--devices 8 \
--num_nodes 4 \
--data RedPajama
```
**Option 2: torchrun**
```bash
# On each node, run:
torchrun \
--nproc_per_node=8 \
--nnodes=4 \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=29500 \
-m litgpt pretrain llama-2-7b
```
### Profiling
Enable profiling to identify bottlenecks:
```bash
litgpt pretrain tiny-llama-1.1b \
--devices 8 \
--train.max_steps 100 \
--profile
# Generates profiling report
```
## Example Configurations
### Llama 2 7B on 4× A100 (40GB)
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--devices 4 \
--precision bf16-mixed \
--train.global_batch_size 64 \
--train.micro_batch_size 4 \
--train.max_seq_length 2048 \
--lora_r 8 \
--data JSON \
--data.json_path data/alpaca.json
```
**Memory per GPU**: ~20GB
**Throughput**: ~5 samples/sec
### Llama 2 70B on 8× A100 (80GB)
```bash
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
--devices 8 \
--precision bf16-mixed \
--train.global_batch_size 32 \
--train.micro_batch_size 1 \
--train.max_seq_length 2048 \
--lora_r 8 \
--data JSON \
--data.json_path data/alpaca.json
```
**Memory per GPU**: ~70GB
**Throughput**: ~1 sample/sec
### Llama 3.1 405B on 64× H100 (80GB)
```bash
litgpt finetune_lora meta-llama/Llama-3.1-405B \
--devices 8 \
--num_nodes 8 \
--precision bf16-mixed \
--train.global_batch_size 128 \
--train.micro_batch_size 1 \
--train.max_seq_length 4096 \
--lora_r 16 \
--data JSON \
--data.json_path data/alpaca.json
```
**Memory per GPU**: ~60GB
**Requires**: 64 H100 GPUs (8 nodes × 8 GPUs)
## Troubleshooting
### "CUDA out of memory"
1. Reduce `micro_batch_size`
2. Increase `devices` (more sharding)
3. Lower `max_seq_length`
4. Use `bf16-mixed` precision
5. Try QLoRA (`--quantize bnb.nf4`)
### "NCCL error" or Slow Communication
1. Check network connectivity between nodes
2. Enable InfiniBand: `NCCL_IB_DISABLE=0`
3. Verify NCCL version: `python -c "import torch; print(torch.cuda.nccl.version())"`
4. Test with NCCL tests: `$NCCL_HOME/build/all_reduce_perf -b 8 -e 128M`
### Training Slower Than Expected
1. Profile with `--profile`
2. Check GPU utilization: `nvidia-smi dmon`
3. Verify data loading isn't bottleneck
4. Increase `micro_batch_size` if memory allows
5. Use Thunder FSDP with bucketing
## References
- FSDP configuration: `litgpt/pretrain.py:setup()`
- Thunder FSDP: `extensions/thunder/pretrain.py`
- Memory optimization guide: `tutorials/oom.md`
- Lightning Fabric docs: https://lightning.ai/docs/fabric/
@@ -0,0 +1,336 @@
# Supported Models
Complete list of model architectures supported by LitGPT with parameter sizes and variants.
## Overview
LitGPT supports **20+ model families** with **100+ model variants** ranging from 135M to 405B parameters.
**List all models**:
```bash
litgpt download list
```
**List pretrain-capable models**:
```bash
litgpt pretrain list
```
## Model Families
### Llama Family
**Llama 3, 3.1, 3.2, 3.3**:
- **Sizes**: 1B, 3B, 8B, 70B, 405B
- **Use Cases**: General-purpose, long-context (128K), multimodal
- **Best For**: Production applications, research, instruction following
**Code Llama**:
- **Sizes**: 7B, 13B, 34B, 70B
- **Use Cases**: Code generation, completion, infilling
- **Best For**: Programming assistants, code analysis
**Function Calling Llama 2**:
- **Sizes**: 7B
- **Use Cases**: Tool use, API integration
- **Best For**: Agents, function execution
**Llama 2**:
- **Sizes**: 7B, 13B, 70B
- **Use Cases**: General-purpose (predecessor to Llama 3)
- **Best For**: Established baselines, research comparisons
**Llama 3.1 Nemotron**:
- **Sizes**: 70B
- **Use Cases**: NVIDIA-optimized variant
- **Best For**: Enterprise deployments
**TinyLlama**:
- **Sizes**: 1.1B
- **Use Cases**: Edge devices, resource-constrained environments
- **Best For**: Fast inference, mobile deployment
**OpenLLaMA**:
- **Sizes**: 3B, 7B, 13B
- **Use Cases**: Open-source Llama reproduction
- **Best For**: Research, education
**Vicuna**:
- **Sizes**: 7B, 13B, 33B
- **Use Cases**: Chatbot, instruction following
- **Best For**: Conversational AI
**R1 Distill Llama**:
- **Sizes**: 8B, 70B
- **Use Cases**: Distilled reasoning models
- **Best For**: Efficient reasoning tasks
**MicroLlama**:
- **Sizes**: 300M
- **Use Cases**: Extremely small Llama variant
- **Best For**: Prototyping, testing
**Platypus**:
- **Sizes**: 7B, 13B, 70B
- **Use Cases**: STEM-focused fine-tune
- **Best For**: Science, math, technical domains
### Mistral Family
**Mistral**:
- **Sizes**: 7B, 123B
- **Use Cases**: Efficient open models, long-context
- **Best For**: Cost-effective deployments
**Mathstral**:
- **Sizes**: 7B
- **Use Cases**: Math reasoning
- **Best For**: Mathematical problem solving
**Mixtral MoE**:
- **Sizes**: 8×7B (47B total, 13B active), 8×22B (141B total, 39B active)
- **Use Cases**: Sparse mixture of experts
- **Best For**: High capacity with lower compute
### Falcon Family
**Falcon**:
- **Sizes**: 7B, 40B, 180B
- **Use Cases**: Open-source models from TII
- **Best For**: Multilingual applications
**Falcon 3**:
- **Sizes**: 1B, 3B, 7B, 10B
- **Use Cases**: Newer Falcon generation
- **Best For**: Efficient multilingual models
### Phi Family (Microsoft)
**Phi 1.5 & 2**:
- **Sizes**: 1.3B, 2.7B
- **Use Cases**: Small language models with strong performance
- **Best For**: Edge deployment, low-resource environments
**Phi 3 & 3.5**:
- **Sizes**: 3.8B
- **Use Cases**: Improved small models
- **Best For**: Mobile, browser-based applications
**Phi 4**:
- **Sizes**: 14B
- **Use Cases**: Medium-size high-performance model
- **Best For**: Balance of size and capability
**Phi 4 Mini Instruct**:
- **Sizes**: 3.8B
- **Use Cases**: Instruction-tuned variant
- **Best For**: Chat, task completion
### Gemma Family (Google)
**Gemma**:
- **Sizes**: 2B, 7B
- **Use Cases**: Google's open models
- **Best For**: Research, education
**Gemma 2**:
- **Sizes**: 2B, 9B, 27B
- **Use Cases**: Second generation improvements
- **Best For**: Enhanced performance
**Gemma 3**:
- **Sizes**: 1B, 4B, 12B, 27B
- **Use Cases**: Latest Gemma generation
- **Best For**: State-of-the-art open models
**CodeGemma**:
- **Sizes**: 7B
- **Use Cases**: Code-specialized Gemma
- **Best For**: Code generation, analysis
### Qwen Family (Alibaba)
**Qwen2.5**:
- **Sizes**: 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B
- **Use Cases**: General-purpose multilingual models
- **Best For**: Chinese/English applications
**Qwen2.5 Coder**:
- **Sizes**: 0.5B, 1.5B, 3B, 7B, 14B, 32B
- **Use Cases**: Code-specialized variants
- **Best For**: Programming in multiple languages
**Qwen2.5 Math**:
- **Sizes**: 1.5B, 7B, 72B
- **Use Cases**: Mathematical reasoning
- **Best For**: Math problems, STEM education
**QwQ & QwQ-Preview**:
- **Sizes**: 32B
- **Use Cases**: Reasoning-focused models (step-by-step problem solving)
- **Best For**: Reasoning tasks
### Pythia Family (EleutherAI)
**Pythia**:
- **Sizes**: 14M, 31M, 70M, 160M, 410M, 1B, 1.4B, 2.8B, 6.9B, 12B
- **Use Cases**: Research, interpretability
- **Best For**: Scientific studies, ablations
### StableLM Family (Stability AI)
**StableLM**:
- **Sizes**: 3B, 7B
- **Use Cases**: Open models from Stability AI
- **Best For**: Research, commercial use
**StableLM Zephyr**:
- **Sizes**: 3B
- **Use Cases**: Instruction-tuned variant
- **Best For**: Chat applications
**StableCode**:
- **Sizes**: 3B
- **Use Cases**: Code generation
- **Best For**: Programming tasks
**FreeWilly2 (Stable Beluga 2)**:
- **Sizes**: 70B
- **Use Cases**: Large Stability AI model
- **Best For**: High-capability tasks
### Other Models
**Danube2**:
- **Sizes**: 1.8B
- **Use Cases**: Efficient small model
- **Best For**: Resource-constrained environments
**Dolly**:
- **Sizes**: 3B, 7B, 12B
- **Use Cases**: Databricks' instruction-following model
- **Best For**: Enterprise applications
**LongChat**:
- **Sizes**: 7B, 13B
- **Use Cases**: Extended context windows
- **Best For**: Long-document understanding
**Nous-Hermes**:
- **Sizes**: 7B, 13B, 70B
- **Use Cases**: Instruction-following fine-tune
- **Best For**: Task completion, reasoning
**OLMo**:
- **Sizes**: 1B, 7B
- **Use Cases**: Allen AI's fully open model
- **Best For**: Research transparency
**RedPajama-INCITE**:
- **Sizes**: 3B, 7B
- **Use Cases**: Open reproduction project
- **Best For**: Research, education
**Salamandra**:
- **Sizes**: 2B, 7B
- **Use Cases**: Multilingual European model
- **Best For**: European language support
**SmolLM2**:
- **Sizes**: 135M, 360M, 1.7B
- **Use Cases**: Ultra-small models
- **Best For**: Edge devices, testing
## Download Examples
**Download specific model**:
```bash
litgpt download meta-llama/Llama-3.2-1B
litgpt download microsoft/phi-2
litgpt download google/gemma-2-9b
```
**Download with HuggingFace token** (for gated models):
```bash
export HF_TOKEN=hf_...
litgpt download meta-llama/Llama-3.1-405B
```
## Model Selection Guide
### By Use Case
**General Chat/Instruction Following**:
- Small: Phi-2 (2.7B), TinyLlama (1.1B)
- Medium: Llama-3.1-8B, Mistral-7B
- Large: Llama-3.1-70B, Mixtral-8x22B
**Code Generation**:
- Small: Qwen2.5-Coder-3B
- Medium: CodeLlama-13B, CodeGemma-7B
- Large: CodeLlama-70B, Qwen2.5-Coder-32B
**Math/Reasoning**:
- Small: Qwen2.5-Math-1.5B
- Medium: Mathstral-7B, Qwen2.5-Math-7B
- Large: QwQ-32B, Qwen2.5-Math-72B
**Multilingual**:
- Small: SmolLM2-1.7B
- Medium: Qwen2.5-7B, Falcon-7B
- Large: Qwen2.5-72B
**Research/Education**:
- Pythia family (14M-12B for ablations)
- OLMo (fully open)
- TinyLlama (fast iteration)
### By Hardware
**Consumer GPU (8-16GB VRAM)**:
- Phi-2 (2.7B)
- TinyLlama (1.1B)
- Gemma-2B
- SmolLM2 family
**Single A100 (40-80GB)**:
- Llama-3.1-8B
- Mistral-7B
- CodeLlama-13B
- Gemma-2-9B
**Multi-GPU (200GB+ total)**:
- Llama-3.1-70B (TP=4)
- Mixtral-8x22B (TP=2)
- Falcon-40B
**Large Cluster**:
- Llama-3.1-405B (FSDP)
- Falcon-180B
## Model Capabilities
### Context Lengths
| Model | Context Window |
|-------|----------------|
| Llama 3.1 | 128K |
| Llama 3.2/3.3 | 128K |
| Mistral-123B | 128K |
| Mixtral | 32K |
| Gemma 2 | 8K |
| Phi-3 | 128K |
| Qwen2.5 | 32K |
### Training Data
- **Llama 3**: 15T tokens (multilingual)
- **Mistral**: Web data, code
- **Qwen**: Multilingual (Chinese/English focus)
- **Pythia**: The Pile (controlled training)
## References
- LitGPT GitHub: https://github.com/Lightning-AI/litgpt
- Model configs: `litgpt/config.py`
- Download tutorial: `tutorials/download_model_weights.md`
@@ -0,0 +1,619 @@
# Training Recipes
Complete hyperparameter configurations for LoRA, QLoRA, and full fine-tuning across different model sizes.
## Overview
LitGPT provides optimized training configurations in `config_hub/finetune/` for various model architectures and fine-tuning methods.
**Key Configuration Files**:
- `config_hub/finetune/*/lora.yaml` - LoRA fine-tuning
- `config_hub/finetune/*/qlora.yaml` - 4-bit quantized LoRA
- `config_hub/finetune/*/full.yaml` - Full fine-tuning
## LoRA Fine-tuning Recipes
### TinyLlama 1.1B LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 8
lr_warmup_steps: 10
epochs: 3
max_seq_length: 512
# LoRA specific
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
```
**Command**:
```bash
litgpt finetune_lora TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
--data JSON \
--data.json_path data/alpaca_sample.json \
--train.global_batch_size 8 \
--train.micro_batch_size 8 \
--train.lr_warmup_steps 10 \
--train.epochs 3 \
--train.max_seq_length 512 \
--lora_r 8 \
--lora_alpha 16
```
**Memory**: ~4GB VRAM
**Time**: ~30 minutes on RTX 3090
### Llama 2 7B LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
# LoRA specific
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
```
**Command**:
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 2 \
--train.lr_warmup_steps 10 \
--train.epochs 4 \
--lora_r 8 \
--lora_alpha 16
```
**Memory**: ~16GB VRAM
**Gradient Accumulation**: 4 steps (8 / 2)
**Time**: ~6 hours on A100
### Llama 3 8B LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 1
lr_warmup_steps: 10
epochs: 2
max_seq_length: 512
# LoRA specific
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
```
**Command**:
```bash
litgpt finetune_lora meta-llama/Meta-Llama-3-8B \
--data JSON \
--data.json_path data/custom_dataset.json \
--train.global_batch_size 8 \
--train.micro_batch_size 1 \
--train.lr_warmup_steps 10 \
--train.epochs 2 \
--lora_r 8
```
**Memory**: ~20GB VRAM
**Gradient Accumulation**: 8 steps
**Time**: ~8 hours on A100
### Mistral 7B LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
lora_r: 8
lora_alpha: 16
```
**Command**:
```bash
litgpt finetune_lora mistralai/Mistral-7B-v0.1 \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 2 \
--train.epochs 4 \
--lora_r 8
```
**Memory**: ~16GB VRAM
### Phi-2 LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 4
lr_warmup_steps: 10
epochs: 1
max_seq_length: 512
lora_r: 8
lora_alpha: 16
```
**Command**:
```bash
litgpt finetune_lora microsoft/phi-2 \
--data JSON \
--data.json_path data/alpaca_sample.json \
--train.global_batch_size 8 \
--train.micro_batch_size 4 \
--train.epochs 1 \
--lora_r 8
```
**Memory**: ~8GB VRAM
**Time**: ~20 minutes on RTX 3090
### Falcon 7B LoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 1
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
lora_r: 8
lora_alpha: 16
```
**Command**:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 1 \
--train.epochs 4 \
--lora_r 8
```
**Memory**: ~18GB VRAM
### Gemma 7B LoRA
**Configuration**:
```yaml
global_batch_size: 6
micro_batch_size: 1
lr_warmup_steps: 200
epochs: 2
max_seq_length: 512
lora_r: 8
lora_alpha: 16
```
**Command**:
```bash
litgpt finetune_lora google/gemma-7b \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 6 \
--train.micro_batch_size 1 \
--train.lr_warmup_steps 200 \
--train.epochs 2 \
--lora_r 8
```
**Memory**: ~18GB VRAM
**Note**: Longer warmup (200 steps) for stability
## QLoRA Fine-tuning Recipes
QLoRA uses 4-bit quantization to reduce memory by ~75%.
### TinyLlama 1.1B QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 8
lr_warmup_steps: 10
epochs: 3
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Command**:
```bash
litgpt finetune_lora TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
--quantize bnb.nf4 \
--data JSON \
--data.json_path data/alpaca_sample.json \
--train.global_batch_size 8 \
--train.micro_batch_size 8 \
--train.epochs 3 \
--lora_r 8
```
**Memory**: ~2GB VRAM (75% reduction)
### Llama 2 7B QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
min_lr: 6.0e-5
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Command**:
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--quantize bnb.nf4 \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 2 \
--train.epochs 4 \
--lora_r 8
```
**Memory**: ~6GB VRAM (consumer GPU friendly)
### Llama 3 8B QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 10
epochs: 2
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Command**:
```bash
litgpt finetune_lora meta-llama/Meta-Llama-3-8B \
--quantize bnb.nf4 \
--data JSON \
--data.json_path data/custom_dataset.json \
--train.global_batch_size 8 \
--train.micro_batch_size 2 \
--train.epochs 2 \
--lora_r 8
```
**Memory**: ~8GB VRAM
### Mistral 7B QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Memory**: ~6GB VRAM
### Phi-2 QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 4
lr_warmup_steps: 10
epochs: 1
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Memory**: ~3GB VRAM
### Falcon 7B QLoRA
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 1
lr_warmup_steps: 10
epochs: 4
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Memory**: ~6GB VRAM
### Gemma 2B QLoRA
**Configuration**:
```yaml
global_batch_size: 6
micro_batch_size: 2
lr_warmup_steps: 200
epochs: 2
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Memory**: ~3GB VRAM
### Gemma 7B QLoRA
**Configuration**:
```yaml
global_batch_size: 6
micro_batch_size: 1
lr_warmup_steps: 200
epochs: 2
max_seq_length: 512
lora_r: 8
lora_alpha: 16
quantize: "bnb.nf4"
```
**Memory**: ~6GB VRAM
## Full Fine-tuning Recipes
Full fine-tuning updates all model parameters (requires more memory).
### TinyLlama 1.1B Full
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 2
lr_warmup_steps: 100
epochs: 3
max_seq_length: 512
learning_rate: 5e-5
```
**Command**:
```bash
litgpt finetune_full TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 2 \
--train.lr_warmup_steps 100 \
--train.epochs 3 \
--train.learning_rate 5e-5
```
**Memory**: ~12GB VRAM
**Time**: ~4 hours on A100
### Phi-2 Full
**Configuration**:
```yaml
global_batch_size: 8
micro_batch_size: 1
lr_warmup_steps: 100
epochs: 2
max_seq_length: 512
learning_rate: 3e-5
```
**Command**:
```bash
litgpt finetune_full microsoft/phi-2 \
--data JSON \
--data.json_path data/alpaca.json \
--train.global_batch_size 8 \
--train.micro_batch_size 1 \
--train.epochs 2 \
--train.learning_rate 3e-5
```
**Memory**: ~24GB VRAM
## Common Hyperparameter Patterns
### Learning Rates
| Model Size | LoRA LR | Full Fine-tune LR |
|------------|---------|-------------------|
| <2B | 3e-4 | 5e-5 |
| 2-10B | 1e-4 | 3e-5 |
| 10-70B | 5e-5 | 1e-5 |
### LoRA Rank (r)
- **r=8**: Default, good balance (recommended)
- **r=16**: More capacity, 2× trainable params
- **r=32**: Maximum capacity, slower training
- **r=4**: Minimal, fastest training
**Rule of thumb**: Start with r=8, increase if underfitting.
### Batch Sizes
| GPU VRAM | Micro Batch | Global Batch |
|----------|-------------|--------------|
| 8GB | 1 | 8 |
| 16GB | 2 | 8-16 |
| 40GB | 4 | 16-32 |
| 80GB | 8 | 32-64 |
### Warmup Steps
- **Small models (<2B)**: 10-50 steps
- **Medium models (2-10B)**: 100-200 steps
- **Large models (>10B)**: 200-500 steps
### Epochs
- **Instruction tuning**: 1-3 epochs
- **Domain adaptation**: 3-5 epochs
- **Small datasets (<10K)**: 5-10 epochs
## Advanced Configurations
### Custom Learning Rate Schedule
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--train.learning_rate 3e-4 \
--train.lr_warmup_steps 100 \
--train.min_lr 3e-6 \
--train.lr_decay_iters 10000
```
### Gradient Accumulation
```bash
# Simulate global_batch_size=128 with 16GB GPU
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--train.global_batch_size 128 \
--train.micro_batch_size 2
# Accumulates over 64 steps (128 / 2)
```
### Mixed Precision
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--precision bf16-mixed # BF16 mixed precision
# or
--precision 16-mixed # FP16 mixed precision
```
### Longer Context
```bash
litgpt finetune_lora meta-llama/Llama-3.1-8B \
--train.max_seq_length 8192 \
--train.micro_batch_size 1 # Reduce batch for memory
```
## Memory Optimization
### Out of Memory? Try These
1. **Enable quantization**:
```bash
--quantize bnb.nf4 # 4-bit QLoRA
```
2. **Reduce batch size**:
```bash
--train.micro_batch_size 1
```
3. **Lower LoRA rank**:
```bash
--lora_r 4 # Instead of 8
```
4. **Use FSDP** (multi-GPU):
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--devices 4 # Use 4 GPUs with FSDP
```
5. **Gradient accumulation** (keep the effective batch size while using a smaller micro batch):
```bash
--train.gradient_accumulation_iters 16
```
## Data Format
LitGPT expects JSON data in instruction format:
```json
[
{
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris."
},
{
"instruction": "Translate to Spanish:",
"input": "Hello world",
"output": "Hola mundo"
}
]
```
**Load custom data**:
```bash
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
--data JSON \
--data.json_path data/my_dataset.json \
--data.val_split_fraction 0.1 # 10% validation
```
## Merge and Deploy
After fine-tuning, merge LoRA weights:
```bash
litgpt merge_lora checkpoints/meta-llama/Llama-2-7b-hf/final_lora.pth
```
Generate with merged model:
```bash
litgpt generate checkpoints/meta-llama/Llama-2-7b-hf-merged/ \
--prompt "What is machine learning?"
```
Or serve via API:
```bash
litgpt serve checkpoints/meta-llama/Llama-2-7b-hf-merged/
```
## References
- Configuration hub: `config_hub/finetune/`
- Fine-tuning tutorial: `tutorials/finetune_*.md`
- Memory guide: `tutorials/oom.md`
@@ -0,0 +1,260 @@
---
name: mamba-architecture
description: State-space model with O(n) complexity vs Transformers' O(n²). 5× faster inference, million-token sequences, no KV cache. Selective SSM with hardware-aware design. Mamba-1 (d_state=16) and Mamba-2 (d_state=128, multi-head). Models 130M-2.8B on HuggingFace.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Model Architecture, Mamba, State Space Models, SSM, Linear Complexity, Long Context, Efficient Inference, Hardware-Aware, Alternative To Transformers]
dependencies: [mamba-ssm, torch, transformers, causal-conv1d]
---
# Mamba - Selective State Space Models
## Quick start
Mamba is a state-space model architecture achieving O(n) linear complexity for sequence modeling.
**Installation**:
```bash
# Install causal-conv1d (optional, for efficiency)
pip install "causal-conv1d>=1.4.0"
# Install Mamba
pip install mamba-ssm
# Or both together
pip install "mamba-ssm[causal-conv1d]"
```
**Prerequisites**: Linux, NVIDIA GPU, PyTorch 1.12+, CUDA 11.6+
**Basic usage** (Mamba block):
```python
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
d_model=dim, # Model dimension
d_state=16, # SSM state dimension
d_conv=4, # Conv1d kernel size
expand=2 # Expansion factor
).to("cuda")
y = model(x) # O(n) complexity!
assert y.shape == x.shape
```
## Common workflows
### Workflow 1: Language model with Mamba-2
**Complete LM with generation**:
```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
import torch
# Configure Mamba-2 LM
config = MambaConfig(
d_model=1024, # Hidden dimension
n_layer=24, # Number of layers
vocab_size=50277, # Vocabulary size
ssm_cfg=dict(
layer="Mamba2", # Use Mamba-2
d_state=128, # Larger state for Mamba-2
headdim=64, # Head dimension
ngroups=1 # Number of groups
)
)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
# Generate text
input_ids = torch.randint(0, 1000, (1, 20), device="cuda", dtype=torch.long)
output = model.generate(
input_ids=input_ids,
max_length=100,
temperature=0.7,
top_p=0.9
)
```
### Workflow 2: Use pretrained Mamba models
**Load from HuggingFace**:
```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Load pretrained model
model_name = "state-spaces/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") # Use compatible tokenizer
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)
# Generate
prompt = "The future of AI is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
output_ids = model.generate(
input_ids=input_ids,
max_length=200,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2
)
generated_text = tokenizer.decode(output_ids[0])
print(generated_text)
```
**Available models**:
- `state-spaces/mamba-130m`
- `state-spaces/mamba-370m`
- `state-spaces/mamba-790m`
- `state-spaces/mamba-1.4b`
- `state-spaces/mamba-2.8b`
### Workflow 3: Mamba-1 vs Mamba-2
**Mamba-1** (smaller state):
```python
from mamba_ssm import Mamba
model = Mamba(
d_model=256,
d_state=16, # Smaller state dimension
d_conv=4,
expand=2
).to("cuda")
```
**Mamba-2** (multi-head, larger state):
```python
from mamba_ssm import Mamba2
model = Mamba2(
d_model=256,
d_state=128, # Larger state dimension
d_conv=4,
expand=2,
headdim=64, # Head dimension for multi-head
ngroups=1 # Parallel groups
).to("cuda")
```
**Key differences**:
- **State size**: Mamba-1 (d_state=16) vs Mamba-2 (d_state=128)
- **Architecture**: Mamba-2 has multi-head structure
- **Normalization**: Mamba-2 uses RMSNorm
- **Distributed**: Mamba-2 supports tensor parallelism
### Workflow 4: Benchmark vs Transformers
**Generation speed comparison**:
```bash
# Benchmark Mamba
python benchmarks/benchmark_generation_mamba_simple.py \
--model-name "state-spaces/mamba-2.8b" \
--prompt "The future of machine learning is" \
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
# Benchmark Transformer
python benchmarks/benchmark_generation_mamba_simple.py \
--model-name "EleutherAI/pythia-2.8b" \
--prompt "The future of machine learning is" \
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
```
**Expected results**:
- **Mamba**: 5× faster inference
- **Memory**: No KV cache needed
- **Scaling**: Linear with sequence length
## When to use vs alternatives
**Use Mamba when**:
- Need long sequences (100K+ tokens)
- Want faster inference than Transformers
- Memory-constrained (no KV cache)
- Building streaming applications
- Linear scaling important
**Advantages**:
- **O(n) complexity**: Linear vs quadratic
- **5× faster inference**: No attention overhead
- **No KV cache**: Lower memory usage
- **Million-token sequences**: Hardware-efficient
- **Streaming**: Constant memory per token
**Use alternatives instead**:
- **Transformers**: Need best-in-class performance, have compute
- **RWKV**: Want RNN+Transformer hybrid
- **RetNet**: Need retention-based architecture
- **Hyena**: Want convolution-based approach
## Common issues
**Issue: CUDA out of memory**
Reduce batch size or use gradient checkpointing:
```python
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
model.gradient_checkpointing_enable() # Enable checkpointing
```
**Issue: Slow installation**
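Skip pip's isolated build environment so the build reuses the PyTorch already installed (this flag does not by itself select a binary wheel):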
```bash
pip install mamba-ssm --no-build-isolation
```
**Issue: Missing causal-conv1d**
Install separately:
```bash
pip install "causal-conv1d>=1.4.0"
```
**Issue: Model not loading from HuggingFace**
Use `MambaLMHeadModel.from_pretrained` (not `AutoModel`):
```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
```
## Advanced topics
**Selective SSM**: See [references/selective-ssm.md](references/selective-ssm.md) for mathematical formulation, state-space equations, and how selectivity enables O(n) complexity.
**Mamba-2 architecture**: See [references/mamba2-details.md](references/mamba2-details.md) for multi-head structure, tensor parallelism, and distributed training setup.
**Performance optimization**: See [references/performance.md](references/performance.md) for hardware-aware design, CUDA kernels, and memory efficiency techniques.
## Hardware requirements
- **GPU**: NVIDIA with CUDA 11.6+
- **VRAM**:
- 130M model: 2GB
- 370M model: 4GB
- 790M model: 8GB
- 1.4B model: 14GB
- 2.8B model: 28GB (FP16)
- **Inference**: 5× faster than Transformers
- **Memory**: No KV cache (lower than Transformers)
**Performance** (vs Transformers):
- **Speed**: 5× faster inference
- **Memory**: 50% less (no KV cache)
- **Scaling**: Linear vs quadratic
## Resources
- Paper (Mamba-1): https://arxiv.org/abs/2312.00752 (Dec 2023)
- Paper (Mamba-2): https://arxiv.org/abs/2405.21060 (May 2024)
- GitHub: https://github.com/state-spaces/mamba ⭐ 13,000+
- Models: https://huggingface.co/state-spaces
- Docs: Repository README and wiki
@@ -0,0 +1,206 @@
# Mamba Architecture Details
## Selective State Space Mechanism
Mamba's core innovation is the **Selective SSM (S6)** layer that makes state space model parameters input-dependent.
### How S6 Works
**Traditional SSMs** (non-selective):
```python
# Fixed A, B, C matrices for all inputs
h(t) = A * h(t-1) + B * x(t) # State update
y(t) = C * h(t) # Output
```
**Mamba's Selective SSM**:
```python
# Input-dependent parameters
B(t) = Linear_B(x(t)) # Selection mechanism
C(t) = Linear_C(x(t)) # Output projection
Δ(t) = Linear_Δ(x(t)) # Discretization step
# Selective state update
h(t) = discretize(A, Δ(t)) * h(t-1) + Δ(t) * B(t) * x(t)
y(t) = C(t) * h(t)
```
### Key Advantages
**1. Content-based reasoning**:
- Can selectively remember or forget based on input
- Addresses discrete modality weakness of traditional SSMs
- Example: Remembers important tokens, forgets padding
**2. Input-dependent selection**:
```python
# Mamba decides per token what to remember
if is_important(x(t)):
Δ(t) = large_value # Keep in state
else:
Δ(t) = small_value # Forget quickly
```
**3. No attention required**:
- Replaces O(n²) attention with O(n) state updates
- State dimension is constant (typically 16)
## Model Configuration
### Core Parameters
```python
from mamba_ssm import Mamba
model = Mamba(
d_model=256, # Hidden dimension (256, 512, 768, 1024, 2048)
d_state=16, # SSM state dimension (fixed at 16 is optimal)
d_conv=4, # Local convolution width (4 is standard)
expand=2, # Expansion factor (1.5-2.0)
dt_rank="auto", # Rank of Δ projection (auto = d_model / 16)
dt_min=0.001, # Min Δ init (controls forgetting rate)
dt_max=0.1, # Max Δ init
dt_init="random", # Δ initialization (random, constant)
dt_scale=1.0, # Δ scaling factor
conv_bias=True, # Use bias in convolution
bias=False # Use bias in linear projections
)
```
### Parameter Impact
**d_state** (SSM state dimension):
- Standard: 16 (optimal from ablations)
- Smaller (8): Faster but less capacity
- Larger (32, 64): Minimal improvement, 2× slower
**expand** (block expansion):
- Standard: 2.0
- Range: 1.5-2.0
- Controls inner dimension = expand * d_model
**d_conv** (convolution width):
- Standard: 4
- Local context window before SSM
- Helps with positional information
**dt_rank** (Δ projection rank):
- Auto: d_model / 16 (recommended)
- Controls Δ parameter efficiency
- Lower rank = more efficient but less expressive
## Mamba Block Structure
```python
# Mamba block (replaces Transformer block)
class MambaBlock(nn.Module):
def __init__(self, d_model):
self.norm = RMSNorm(d_model)
self.mamba = Mamba(d_model, d_state=16, d_conv=4, expand=2)
def forward(self, x):
return x + self.mamba(self.norm(x)) # Residual
# Full model (stack of Mamba blocks)
model = nn.Sequential(
Embedding(...),
*[MambaBlock(d_model) for _ in range(n_layers)],
RMSNorm(d_model),
LMHead(...)
)
```
**Key differences from Transformers**:
- No multi-head attention (MHA)
- No feedforward network (FFN)
- Single Mamba layer per block
- 2× more layers than equivalent Transformer
## Hardware-Aware Implementation
### Parallel Algorithm
Mamba uses a **scan-based parallel algorithm** for training:
```python
# Parallel mode (training)
# GPU kernel fuses operations
y = parallel_scan(A, B, C, x) # O(n) work, O(log n) parallel depth
# Sequential mode (inference)
# Constant memory RNN-style
h = 0
for x_t in sequence:
h = A*h + B*x_t
y_t = C*h
```
### Memory Efficiency
**Training**:
- Recomputes activations in backward pass
- Similar to FlashAttention strategy
- Memory: O(batch_size * seq_len * d_model)
**Inference**:
- RNN-style sequential processing
- State size: O(d_model * d_state) = constant
- No KV cache needed (huge advantage!)
### CUDA Kernel Optimizations
```python
# Fused kernel operations
- Discretization (continuous → discrete A, B)
- SSM recurrence (parallel scan)
- Convolution (efficient 1D conv)
- All in single GPU kernel
```
## Layer Count Scaling
Mamba models use **2× layers** compared to Transformers:
| Model | d_model | n_layers | Params |
|-------|---------|----------|--------|
| Mamba-130M | 768 | 24 | 130M |
| Mamba-370M | 1024 | 48 | 370M |
| Mamba-790M | 1536 | 48 | 790M |
| Mamba-1.4B | 2048 | 48 | 1.4B |
| Mamba-2.8B | 2560 | 64 | 2.8B |
**Why 2× layers?**
- Mamba blocks are simpler (no MHA, no FFN)
- ~50% fewer parameters per layer
- Doubling layers matches compute budget
## Initialization Strategy
```python
# Δ (discretization step) initialization
dt_init_floor = 1e-4
dt = torch.exp(
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# A (state transition) initialization
A = -torch.exp(torch.rand(d_inner, d_state)) # Negative for stability
# B, C (input/output) initialization
B = torch.randn(d_inner, d_state)
C = torch.randn(d_inner, d_state)
```
**Critical for stability**:
- A must be negative (exponential decay)
- Δ in range [dt_min, dt_max]
- Random initialization helps diversity
## Resources
- Paper: https://arxiv.org/abs/2312.00752 (Mamba-1)
- Paper: https://arxiv.org/abs/2405.21060 (Mamba-2)
- GitHub: https://github.com/state-spaces/mamba
- Models: https://huggingface.co/state-spaces
- CUDA kernels: https://github.com/state-spaces/mamba/tree/main/csrc
@@ -0,0 +1,255 @@
# Mamba Performance Benchmarks
## Inference Speed Comparison
### Throughput (tokens/sec)
**Mamba-1.4B vs Transformer-1.3B** on single A100 80GB:
| Sequence Length | Mamba-1.4B | Transformer-1.3B | Speedup |
|----------------|------------|------------------|---------|
| 512 | 8,300 | 6,200 | 1.3× |
| 1024 | 7,800 | 4,100 | 1.9× |
| 2048 | 7,200 | 2,300 | 3.1× |
| 4096 | 6,800 | 1,200 | 5.7× |
| 8192 | 6,400 | 600 | **10.7×** |
| 16384 | 6,100 | OOM | ∞ |
**Key insight**: Speedup grows with sequence length (Mamba O(n) vs Transformer O(n²))
### Latency (ms per token)
**Generation latency** (batch size 1, autoregressive):
| Model | First Token | Per Token | 100 Tokens Total |
|-------|-------------|-----------|------------------|
| Mamba-130M | 3 ms | 0.8 ms | 83 ms |
| Transformer-130M | 5 ms | 1.2 ms | 125 ms |
| Mamba-1.4B | 12 ms | 3.2 ms | 332 ms |
| Transformer-1.3B | 18 ms | 8.5 ms | 868 ms |
| Mamba-2.8B | 20 ms | 6.1 ms | 630 ms |
| Transformer-2.7B | 35 ms | 18.2 ms | 1855 ms |
**Mamba advantage**: Constant per-token latency regardless of context length
## Memory Usage
### Training Memory (BF16, per GPU)
**Mamba-1.4B** training memory breakdown:
| Sequence Length | Activations | Gradients | Optimizer | Total | vs Transformer |
|----------------|-------------|-----------|-----------|-------|----------------|
| 512 | 2.1 GB | 3.2 GB | 11.2 GB | 16.5 GB | 0.9× |
| 1024 | 3.8 GB | 3.2 GB | 11.2 GB | 18.2 GB | 0.6× |
| 2048 | 7.2 GB | 3.2 GB | 11.2 GB | 21.6 GB | 0.4× |
| 4096 | 14.1 GB | 3.2 GB | 11.2 GB | 28.5 GB | 0.25× |
| 8192 | 28.0 GB | 3.2 GB | 11.2 GB | 42.4 GB | 0.15× |
**Note**: Transformer OOMs at 8K sequence length on 40GB A100
### Inference Memory (FP16, batch size 1)
| Model | KV Cache (8K ctx) | State (Mamba) | Ratio |
|-------|------------------|---------------|-------|
| 130M | 2.1 GB | 0 MB | ∞ |
| 370M | 5.2 GB | 0 MB | ∞ |
| 1.4B | 19.7 GB | 0 MB | ∞ |
| 2.8B | 38.4 GB | 0 MB | ∞ |
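**Mamba stores no KV cache** - its recurrent state has a fixed size, so inference memory does not grow with context length. The 0 MB entries above are rounded; the state is small but non-zero.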
Actual Mamba state size:
- 130M: ~3 MB (d_model × d_state × n_layers = 768 × 16 × 24)
- 2.8B: ~13 MB (2560 × 16 × 64)
## Language Modeling Benchmarks
### Perplexity on Common Datasets
**Models trained on The Pile (300B tokens)**:
| Model | Params | Pile (val) | WikiText-103 | C4 | Lambada |
|-------|--------|------------|--------------|-----|---------|
| Pythia | 160M | 29.6 | 28.4 | 23.1 | 51.2 |
| **Mamba** | **130M** | **28.1** | **26.7** | **21.8** | **48.3** |
| Pythia | 410M | 18.3 | 17.6 | 16.2 | 32.1 |
| **Mamba** | **370M** | **16.7** | **16.2** | **15.1** | **28.4** |
| Pythia | 1.4B | 10.8 | 10.2 | 11.3 | 15.2 |
| **Mamba** | **1.4B** | **9.1** | **9.6** | **10.1** | **12.8** |
| Pythia | 2.8B | 8.3 | 7.9 | 9.2 | 10.6 |
| **Mamba** | **2.8B** | **7.4** | **7.2** | **8.3** | **9.1** |
**Mamba consistently outperforms** Transformers of similar size, with roughly 5-16% lower Pile validation perplexity in the table above
### Zero-Shot Task Performance
**Mamba-2.8B vs Transformer-2.7B** on common benchmarks:
| Task | Mamba-2.8B | Transformer-2.7B | Delta |
|------|------------|------------------|-------|
| HellaSwag | 61.3 | 58.7 | +2.6 |
| PIQA | 78.1 | 76.4 | +1.7 |
| ARC-Easy | 68.2 | 65.9 | +2.3 |
| ARC-Challenge | 42.7 | 40.1 | +2.6 |
| WinoGrande | 64.8 | 62.3 | +2.5 |
| OpenBookQA | 43.2 | 41.8 | +1.4 |
| BoolQ | 71.4 | 68.2 | +3.2 |
| MMLU (5-shot) | 35.2 | 33.8 | +1.4 |
**Average improvement**: +2.2 points across benchmarks
## Audio Modeling Benchmarks
### SC09 (Speech Commands)
**Task**: Audio classification (10 classes)
| Model | Params | Accuracy | Inference (ms) |
|-------|--------|----------|----------------|
| Transformer | 8.2M | 96.2% | 18 ms |
| S4 | 6.1M | 97.1% | 8 ms |
| **Mamba** | **6.3M** | **98.4%** | **6 ms** |
### LJSpeech (Speech Generation)
**Task**: Text-to-speech quality (MOS score)
| Model | Params | MOS ↑ | RTF ↓ |
|-------|--------|-------|-------|
| Transformer | 12M | 3.82 | 0.45 |
| Conformer | 11M | 3.91 | 0.38 |
| **Mamba** | **10M** | **4.03** | **0.21** |
**RTF** (Real-Time Factor): Lower is better (0.21 = 5× faster than real-time)
## Genomics Benchmarks
### Human Reference Genome (HG38)
**Task**: Next nucleotide prediction
| Model | Context Length | Perplexity | Throughput |
|-------|----------------|------------|------------|
| Transformer | 1024 | 3.21 | 1,200 bp/s |
| Hyena | 32768 | 2.87 | 8,500 bp/s |
| **Mamba** | **1M** | **2.14** | **45,000 bp/s** |
**Mamba handles million-length sequences** efficiently
## Scaling Laws
### Compute-Optimal Training
**FLOPs vs perplexity** (The Pile validation):
| Model Size | Training FLOPs | Mamba Perplexity | Transformer Perplexity |
|------------|----------------|------------------|------------------------|
| 130M | 6e19 | 28.1 | 29.6 |
| 370M | 3e20 | 16.7 | 18.3 |
| 790M | 8e20 | 12.3 | 13.9 |
| 1.4B | 2e21 | 9.1 | 10.8 |
| 2.8B | 6e21 | 7.4 | 8.3 |
**Scaling coefficient**: Mamba achieves same perplexity as Transformer with **0.8×** compute
### Parameter Efficiency
**Perplexity 10.0 target** on The Pile:
| Model Type | Parameters Needed | Memory (inference) |
|------------|-------------------|-------------------|
| Transformer | 1.6B | 3.2 GB |
| **Mamba** | **1.1B** | **2.2 GB** |
**Mamba needs ~30% fewer parameters** for same performance
## Long-Range Arena (LRA)
**Task**: Long-context understanding benchmarks
| Task | Length | Transformer | S4 | Mamba |
|------|--------|-------------|-----|-------|
| ListOps | 2K | 36.4% | 59.6% | **61.2%** |
| Text | 4K | 64.3% | 86.8% | **88.1%** |
| Retrieval | 4K | 57.5% | 90.9% | **92.3%** |
| Image | 1K | 42.4% | 88.7% | **89.4%** |
| PathFinder | 1K | 71.4% | 86.1% | **87.8%** |
| Path-X | 16K | OOM | 88.3% | **91.2%** |
**Average**: Mamba 85.0%, S4 83.4%, Transformer 54.4%
## Training Throughput
### Tokens/sec During Training
**8× A100 80GB** cluster, BF16, different sequence lengths:
| Model | Seq Len 512 | Seq Len 2K | Seq Len 8K | Seq Len 32K |
|-------|-------------|------------|------------|-------------|
| Transformer-1.3B | 180K | 52K | OOM | OOM |
| **Mamba-1.4B** | **195K** | **158K** | **121K** | **89K** |
| Transformer-2.7B | 92K | 26K | OOM | OOM |
| **Mamba-2.8B** | **98K** | **81K** | **62K** | **45K** |
**Mamba scales to longer sequences** without OOM
## Hardware Utilization
### GPU Memory Bandwidth
**Mamba-1.4B** inference on different GPUs:
| GPU | Memory BW | Tokens/sec | Efficiency |
|-----|-----------|------------|------------|
| A100 80GB | 2.0 TB/s | 6,800 | 85% |
| A100 40GB | 1.6 TB/s | 5,400 | 84% |
| V100 32GB | 900 GB/s | 3,100 | 86% |
| RTX 4090 | 1.0 TB/s | 3,600 | 90% |
**High efficiency**: Mamba is memory-bandwidth bound (good!)
### Multi-GPU Scaling
**Mamba-2.8B** training throughput:
| GPUs | Tokens/sec | Scaling Efficiency |
|------|------------|-------------------|
| 1× A100 | 12,300 | 100% |
| 2× A100 | 23,800 | 97% |
| 4× A100 | 46,100 | 94% |
| 8× A100 | 89,400 | 91% |
| 16× A100 | 172,000 | 87% |
**Near-linear scaling** up to 16 GPUs
## Cost Analysis
### Training Cost (USD)
**Training to The Pile perplexity 10.0** on cloud GPUs:
| Model | Cloud GPUs | Hours | Cost (A100) | Cost (H100) |
|-------|------------|-------|-------------|-------------|
| Transformer-1.6B | 8× A100 | 280 | $8,400 | $4,200 |
| **Mamba-1.1B** | **8× A100** | **180** | **$5,400** | **$2,700** |
**Savings**: 36% cost reduction vs Transformer
### Inference Cost (USD/million tokens)
**API-style inference** (batch size 1, 2K context):
| Model | Latency | Cost/M tokens | Quality (perplexity) |
|-------|---------|---------------|---------------------|
| Transformer-1.3B | 8.5 ms/tok | $0.42 | 10.8 |
| **Mamba-1.4B** | **3.2 ms/tok** | **$0.18** | **9.1** |
**Mamba provides**: 2.6× faster, 57% cheaper, better quality
## Resources
- Benchmarks code: https://github.com/state-spaces/mamba/tree/main/benchmarks
- Paper (Mamba-1): https://arxiv.org/abs/2312.00752 (Section 4: Experiments)
- Paper (Mamba-2): https://arxiv.org/abs/2405.21060 (Section 5: Experiments)
- Pretrained models: https://huggingface.co/state-spaces
@@ -0,0 +1,388 @@
# Mamba Training Guide
## Training from Scratch
### Setup Environment
```bash
# Install dependencies
pip install "torch>=1.12.0" --extra-index-url https://download.pytorch.org/whl/cu116
pip install packaging ninja
pip install "causal-conv1d>=1.1.0"
pip install mamba-ssm
# Verify CUDA
python -c "import torch; print(torch.cuda.is_available())"
```
### Basic Training Loop
```python
import torch
import torch.nn.functional as F
from mamba_ssm import Mamba
from torch.utils.data import DataLoader
# Model setup
model = Mamba(
d_model=512,
d_state=16,
d_conv=4,
expand=2
).cuda()
# Optimizer (same as GPT)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=6e-4,
betas=(0.9, 0.95),
weight_decay=0.1
)
# Training loop
for batch in dataloader:
inputs, targets = batch
inputs, targets = inputs.cuda(), targets.cuda()
# Forward
logits = model(inputs)
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
# Backward
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
```
## Distributed Training
### Single-Node Multi-GPU (DDP)
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Initialize process group
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# Wrap model
model = Mamba(...).cuda()
model = DDP(model, device_ids=[local_rank])
# Train
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
for batch in dataloader:
loss = compute_loss(model, batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
**Launch**:
```bash
torchrun --nproc_per_node=8 train.py
```
### Multi-Node Training
```bash
# Node 0 (master)
torchrun --nproc_per_node=8 \
--nnodes=4 --node_rank=0 \
--master_addr=$MASTER_ADDR --master_port=29500 \
train.py
# Node 1-3 (workers)
torchrun --nproc_per_node=8 \
--nnodes=4 --node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR --master_port=29500 \
train.py
```
## Mixed Precision Training
### BF16 (Recommended)
```python
from torch.cuda.amp import autocast, GradScaler
# BF16 (no scaler needed on A100/H100)
for batch in dataloader:
with autocast(dtype=torch.bfloat16):
logits = model(inputs)
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
### FP16 (with gradient scaling)
```python
scaler = GradScaler()
for batch in dataloader:
with autocast(dtype=torch.float16):
logits = model(inputs)
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
```
## Hyperparameter Recommendations
### Learning Rate Schedule
```python
# Cosine decay with warmup (GPT-3 style)
def get_lr(it, warmup_iters=2000, lr_decay_iters=600000):
max_lr = 6e-4
min_lr = 6e-5
# Warmup
if it < warmup_iters:
return max_lr * it / warmup_iters
# Decay
if it > lr_decay_iters:
return min_lr
# Cosine
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (max_lr - min_lr)
# Apply in training loop
for it, batch in enumerate(dataloader):
lr = get_lr(it)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
```
### Batch Size Recommendations
| Model Size | Per-GPU Batch | Gradient Accum | Effective Batch | GPUs |
|------------|---------------|----------------|-----------------|------|
| 130M | 32 | 4 | 1024 | 8 |
| 370M | 16 | 8 | 1024 | 8 |
| 790M | 8 | 8 | 512 | 8 |
| 1.4B | 4 | 16 | 512 | 8 |
| 2.8B | 2 | 16 | 256 | 8 |
```python
# Gradient accumulation
accumulation_steps = 8
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
loss = compute_loss(model, batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
```
### Optimizer Configuration
```python
# AdamW (recommended)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=6e-4, # Peak learning rate
betas=(0.9, 0.95), # Standard for LLMs
eps=1e-8,
weight_decay=0.1 # Important for generalization
)
# Weight decay exemptions (optional)
decay = set()
no_decay = set()
for name, param in model.named_parameters():
if 'norm' in name or 'bias' in name:
no_decay.add(param)
else:
decay.add(param)
optimizer = torch.optim.AdamW([
{'params': list(decay), 'weight_decay': 0.1},
{'params': list(no_decay), 'weight_decay': 0.0}
], lr=6e-4, betas=(0.9, 0.95))
```
## Memory Optimization
### Gradient Checkpointing
```python
from torch.utils.checkpoint import checkpoint
class MambaBlock(nn.Module):
def __init__(self, d_model, use_checkpoint=False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm = RMSNorm(d_model)
self.mamba = Mamba(d_model)
def forward(self, x):
if self.use_checkpoint and self.training:
return x + checkpoint(self._forward, x, use_reentrant=False)
return x + self._forward(x)
def _forward(self, x):
return self.mamba(self.norm(x))
# Enable for training
model = MambaLM(use_checkpoint=True)
```
**Memory savings**: ~30-40% with minimal speed impact
### Flash Attention Integration
Mamba has no attention layers, so FlashAttention itself is not used; instead, its CUDA kernels apply the same IO-aware optimizations:
- Fused operations in single kernel
- Recomputation in backward pass
- No intermediate activation storage
## Long Context Training
### Sequence Length Progression
```python
# Start short, increase gradually
training_stages = [
{'seq_len': 512, 'iters': 50000},
{'seq_len': 1024, 'iters': 100000},
{'seq_len': 2048, 'iters': 150000},
{'seq_len': 4096, 'iters': 200000},
]
for stage in training_stages:
dataloader = create_dataloader(seq_len=stage['seq_len'])
train(model, dataloader, max_iters=stage['iters'])
```
### Memory Requirements (Batch Size 1)
| Sequence Length | 130M Model | 370M Model | 1.4B Model |
|----------------|------------|------------|------------|
| 2K | 4 GB | 8 GB | 24 GB |
| 4K | 5 GB | 10 GB | 32 GB |
| 8K | 6 GB | 14 GB | 48 GB |
| 16K | 8 GB | 20 GB | 64 GB |
| 32K | 12 GB | 32 GB | 96 GB |
**Mamba advantage**: Memory grows **linearly**, Transformers grow **quadratically**
## Common Training Issues
### Issue: OOM during training
**Solution 1**: Reduce batch size
```python
per_gpu_batch = 8 # Reduce from 16
gradient_accumulation = 8 # Increase from 4
```
**Solution 2**: Enable gradient checkpointing
```python
model = MambaLM(use_checkpoint=True)
```
**Solution 3**: Use smaller sequence length
```python
seq_len = 1024 # Reduce from 2048
```
### Issue: Training unstable (loss spikes)
**Solution 1**: Check gradient norm
```python
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
print(f"Grad norm: {grad_norm}") # Should be < 10
```
**Solution 2**: Lower learning rate
```python
max_lr = 3e-4 # Reduce from 6e-4
```
**Solution 3**: Check Δ initialization
```python
# Ensure dt_min, dt_max are reasonable
model = Mamba(
d_model=512,
dt_min=0.001, # Not too small
dt_max=0.1 # Not too large
)
```
### Issue: Slow training speed
**Solution 1**: Verify CUDA kernels installed
```python
import mamba_ssm
print(mamba_ssm.__version__) # An ImportError on the line above usually means the CUDA extension was not built
```
**Solution 2**: Use BF16 on A100/H100
```python
with autocast(dtype=torch.bfloat16): # No loss scaling needed, unlike FP16
loss = model(inputs)
```
**Solution 3**: Increase batch size if possible
```python
per_gpu_batch = 16 # Increase from 8 (better GPU utilization)
```
## Checkpointing
### Save/Load Model
```python
# Save
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'iter': iteration,
'config': model_config
}
torch.save(checkpoint, f'checkpoint_{iteration}.pt')
# Load
checkpoint = torch.load('checkpoint_100000.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
iteration = checkpoint['iter']
```
### Best Practices
```python
# Save every N iterations
if iteration % save_interval == 0:
save_checkpoint(model, optimizer, iteration)
# Keep only last K checkpoints
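checkpoints = sorted(glob.glob('checkpoint_*.pt'), key=os.path.getmtime)  # oldest first; a plain name sort would put checkpoint_100000 before checkpoint_50000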
if len(checkpoints) > keep_last:
for ckpt in checkpoints[:-keep_last]:
os.remove(ckpt)
```
## Resources
- Training code: https://github.com/state-spaces/mamba/tree/main/benchmarks
- Pretrained models: https://huggingface.co/state-spaces
- CUDA installation: https://github.com/state-spaces/mamba#installation
@@ -0,0 +1,290 @@
---
name: nanogpt
description: Educational GPT implementation in ~300 lines. Reproduces GPT-2 (124M) on OpenWebText. Clean, hackable code for learning transformers. By Andrej Karpathy. Perfect for understanding GPT architecture from scratch. Train on Shakespeare (CPU) or OpenWebText (multi-GPU).
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Model Architecture, NanoGPT, GPT-2, Educational, Andrej Karpathy, Transformer, Minimalist, From Scratch, Training]
dependencies: [torch, transformers, datasets, tiktoken, wandb]
---
# nanoGPT - Minimalist GPT Training
## Quick start
nanoGPT is a simplified GPT implementation designed for learning and experimentation.
**Installation**:
```bash
pip install torch numpy transformers datasets tiktoken wandb tqdm
```
**Train on Shakespeare** (CPU-friendly):
```bash
# Prepare data
python data/shakespeare_char/prepare.py
# Train (5 minutes on CPU)
python train.py config/train_shakespeare_char.py
# Generate text
python sample.py --out_dir=out-shakespeare-char
```
**Output**:
```
ROMEO:
What say'st thou? Shall I speak, and be a man?
JULIET:
I am afeard, and yet I'll speak; for thou art
One that hath been a man, and yet I know not
What thou art.
```
## Common workflows
### Workflow 1: Character-level Shakespeare
**Complete training pipeline**:
```bash
# Step 1: Prepare data (creates train.bin, val.bin)
python data/shakespeare_char/prepare.py
# Step 2: Train small model
python train.py config/train_shakespeare_char.py
# Step 3: Generate text
python sample.py --out_dir=out-shakespeare-char
```
**Config** (`config/train_shakespeare_char.py`):
```python
# Model config
n_layer = 6 # 6 transformer layers
n_head = 6 # 6 attention heads
n_embd = 384 # 384-dim embeddings
block_size = 256 # 256 char context
# Training config
batch_size = 64
learning_rate = 1e-3
max_iters = 5000
eval_interval = 500
# Hardware
device = 'cpu' # Or 'cuda'
compile = False # Set True for PyTorch 2.0
```
**Training time**: ~5 minutes (CPU), ~1 minute (GPU)
### Workflow 2: Reproduce GPT-2 (124M)
**Multi-GPU training on OpenWebText**:
```bash
# Step 1: Prepare OpenWebText (takes ~1 hour)
python data/openwebtext/prepare.py
# Step 2: Train GPT-2 124M with DDP (8 GPUs)
torchrun --standalone --nproc_per_node=8 \
train.py config/train_gpt2.py
# Step 3: Sample from trained model
python sample.py --out_dir=out
```
**Config** (`config/train_gpt2.py`):
```python
# GPT-2 (124M) architecture
n_layer = 12
n_head = 12
n_embd = 768
block_size = 1024
dropout = 0.0
# Training
batch_size = 12
gradient_accumulation_steps = 5 * 8 # Total batch ~0.5M tokens
learning_rate = 6e-4
max_iters = 600000
lr_decay_iters = 600000
# System
compile = True # PyTorch 2.0
```
**Training time**: ~4 days (8× A100)
### Workflow 3: Fine-tune pretrained GPT-2
**Start from OpenAI checkpoint**:
```python
# In train.py or config
init_from = 'gpt2' # Options: gpt2, gpt2-medium, gpt2-large, gpt2-xl
# Model loads OpenAI weights automatically when you launch training from a shell:
#   python train.py config/finetune_shakespeare.py
```
**Example config** (`config/finetune_shakespeare.py`):
```python
# Start from GPT-2
init_from = 'gpt2'
# Dataset
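dataset = 'shakespeare' # BPE-tokenized set (python data/shakespeare/prepare.py); the char-level set does not match GPT-2's vocabulary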
batch_size = 1
block_size = 1024
# Fine-tuning
learning_rate = 3e-5 # Lower LR for fine-tuning
max_iters = 2000
warmup_iters = 100
# Regularization
weight_decay = 1e-1
```
### Workflow 4: Custom dataset
**Train on your own text**:
```python
# data/custom/prepare.py
# Turns one text file into character-level train/val token files plus vocabulary metadata.
import os
import pickle

import numpy as np

# Load your data (explicit encoding so the vocabulary is identical on every platform)
with open('my_data.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create character mappings
chars = sorted(set(text))
vocab_size = len(chars)
# Token IDs are stored as uint16 below, so every ID must fit in 16 bits
assert vocab_size <= 2**16, f"vocab_size {vocab_size} does not fit in uint16"
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Tokenize
data = np.array([stoi[ch] for ch in text], dtype=np.uint16)

# Split train/val (90/10)
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# Save (create the output directory first; tofile() does not create it)
out_dir = 'data/custom'
os.makedirs(out_dir, exist_ok=True)
train_data.tofile(os.path.join(out_dir, 'train.bin'))
val_data.tofile(os.path.join(out_dir, 'val.bin'))

# Save the vocabulary next to the token files, in the same meta.pkl layout
# the other character-level preparation scripts write
with open(os.path.join(out_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump({'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}, f)
```
**Train**:
```bash
python data/custom/prepare.py
python train.py --dataset=custom
```
## When to use vs alternatives
**Use nanoGPT when**:
- Learning how GPT works
- Experimenting with transformer variants
- Teaching/education purposes
- Quick prototyping
- Limited compute (can run on CPU)
**Simplicity advantages**:
- **~300 lines**: Entire model in `model.py`
- **~300 lines**: Training loop in `train.py`
- **Hackable**: Easy to modify
- **No abstractions**: Pure PyTorch
**Use alternatives instead**:
- **HuggingFace Transformers**: Production use, many models
- **Megatron-LM**: Large-scale distributed training
- **LitGPT**: More architectures, production-ready
- **PyTorch Lightning**: Need high-level framework
## Common issues
**Issue: CUDA out of memory**
Reduce batch size or context length:
```python
batch_size = 1 # Reduce from 12
block_size = 512 # Reduce from 1024
# 12× fewer sequences per micro-batch needs 12× more micro-steps to keep the
# same number of sequences per optimizer step (the default is 40)
gradient_accumulation_steps = 40 * 12 # Increase from 40 to maintain effective batch
```
**Issue: Training too slow**
Enable compilation (PyTorch 2.0+):
```python
compile = True # 2× speedup
```
Use mixed precision:
```python
dtype = 'bfloat16' # Or 'float16'
```
**Issue: Poor generation quality**
Train longer:
```python
max_iters = 10000 # Increase from 5000
```
Lower temperature:
```python
# In sample.py
temperature = 0.7 # Lower from 1.0
top_k = 200 # Add top-k sampling
```
**Issue: Can't load GPT-2 weights**
Install transformers:
```bash
pip install transformers
```
Check model name:
```python
init_from = 'gpt2' # Valid: gpt2, gpt2-medium, gpt2-large, gpt2-xl
```
## Advanced topics
**Model architecture**: See [references/architecture.md](references/architecture.md) for GPT block structure, multi-head attention, and MLP layers explained simply.
**Training loop**: See [references/training.md](references/training.md) for learning rate schedule, gradient accumulation, and distributed data parallel setup.
**Data preparation**: See [references/data.md](references/data.md) for tokenization strategies (character-level vs BPE) and binary format details.
## Hardware requirements
- **Shakespeare (char-level)**:
- CPU: 5 minutes
- GPU (T4): 1 minute
- VRAM: <1GB
- **GPT-2 (124M)**:
- 1× A100: ~1 month (estimated as roughly 8× the 8-GPU time)
- 8× A100: ~4 days
- VRAM: ~16GB per GPU
- **GPT-2 Medium (350M)**:
- 8× A100: ~2 weeks
- VRAM: ~40GB per GPU
**Performance**:
- With `compile=True`: 2× speedup
- With `dtype=bfloat16`: 50% memory reduction
## Resources
- GitHub: https://github.com/karpathy/nanoGPT ⭐ 48,000+
- Video: "Let's build GPT" by Andrej Karpathy
- Paper: "Attention is All You Need" (Vaswani et al.)
- OpenWebText: https://huggingface.co/datasets/Skylion007/openwebtext
- Educational: Best for understanding transformers from scratch
@@ -0,0 +1,382 @@
# NanoGPT Architecture
## Model Structure (~300 Lines)
NanoGPT implements a clean GPT-2 architecture in minimal code for educational purposes.
### Complete Model (model.py)
```python
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees itself and earlier positions.

    Reads ``n_embd``, ``n_head``, ``bias``, ``dropout`` and ``block_size`` from ``config``.
    Uses ``scaled_dot_product_attention`` when this PyTorch build provides it and
    falls back to an explicit masked softmax otherwise.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single fused projection yields queries, keys and values for every head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection applied after the heads are merged back together
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Dropout on the attention weights and on the block output
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Fused attention kernel is available from PyTorch 2.0 onwards
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask, only needed by the manual fallback path
            n_ctx = config.block_size
            lower_tri = torch.tril(torch.ones(n_ctx, n_ctx))
            self.register_buffer("bias", lower_tri.view(1, 1, n_ctx, n_ctx))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, n_embd)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        if not self.flash:
            # Manual attention: scaled scores, causal mask, softmax, dropout
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            visible = self.bias[:, :, :seq_len, :seq_len]
            scores = scores.masked_fill(visible == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v  # (B, nh, T, hs)
        else:
            # Fused kernel; attention dropout is only active while training
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )

        # Merge the heads back into one (B, T, C) tensor, then project
        merged = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4×n_embd, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the feed-forward transform of ``x``; the last dimension stays n_embd."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each added back to the input."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Normalize first, then add the sub-layer output to the residual stream
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
@dataclass
class GPTConfig:
    """GPT model configuration.

    Hyperparameters read by the model classes in this file. The defaults for
    n_layer / n_head / n_embd are the 12 / 12 / 768 GPT-2 Small values.
    """
    block_size: int = 1024 # Max sequence length (also the size of the position-embedding table)
    vocab_size: int = 50304 # GPT-2 vocab size (50257 rounded up for efficiency)
    n_layer: int = 12 # Number of layers
    n_head: int = 12 # Number of attention heads
    n_embd: int = 768 # Embedding dimension
    dropout: float = 0.0 # Dropout rate
    bias: bool = True # Use bias in the attention and MLP Linear layers (lm_head and the nn.LayerNorm layers in this file do not read it)
class GPT(nn.Module):
    """GPT language model.

    Token and learned position embeddings feed a stack of ``config.n_layer``
    Blocks, followed by a final LayerNorm and a linear head whose weight is
    shared with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # Token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # Position embeddings
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying (share embeddings and output projection)
        self.transformer.wte.weight = self.lm_head.weight
        # Initialize weights
        self.apply(self._init_weights)
        # Residual output projections get a smaller std that shrinks with depth
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear and Embedding weights; zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token IDs.

        Args:
            idx: integer tensor of token IDs with shape (b, t), t <= block_size.
            targets: optional tensor of target IDs laid out like ``idx``;
                positions equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position and
            loss is the cross-entropy. Without targets, logits cover only the
            last position, shape (b, 1, vocab_size), and loss is None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence length {t}, max is {self.config.block_size}"
        # Generate position indices
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)
        # Forward the GPT model
        tok_emb = self.transformer.wte(idx)  # Token embeddings (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # Position embeddings (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # Training mode: compute loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference mode: only compute logits for last token
            logits = self.lm_head(x[:, [-1], :])  # (b, 1, vocab_size)
            loss = None
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build the AdamW optimizer used by the training script.

        Weight decay is applied only to parameters with two or more dimensions
        (the Linear and Embedding weight matrices); one-dimensional parameters
        (biases, LayerNorm gains) are left undecayed.

        Args:
            weight_decay: decay coefficient for the 2D+ parameters.
            learning_rate: initial learning rate for every parameter group.
            betas: (beta1, beta2) tuple passed to AdamW.
            device_type: accepted so the four-argument call in the training
                script works; not used by this implementation.

        Returns:
            A ``torch.optim.AdamW`` instance with two parameter groups.
        """
        # self.parameters() yields the tied embedding/head weight only once
        trainable = [p for p in self.parameters() if p.requires_grad]
        decay = [p for p in trainable if p.dim() >= 2]
        no_decay = [p for p in trainable if p.dim() < 2]
        groups = [
            {'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(groups, lr=learning_rate, betas=betas)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate new tokens autoregressively.

        Appends ``max_new_tokens`` sampled tokens to ``idx`` (shape (b, t)) and
        returns the extended tensor. ``temperature`` divides the logits before
        sampling; ``top_k`` (if given) restricts sampling to the k most likely tokens.
        """
        for _ in range(max_new_tokens):
            # Crop context if needed
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature  # Scale by temperature
            # Optionally crop logits to top k
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # Sample from distribution
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append to sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
```
## Key Design Decisions
### 1. Pre-Norm vs Post-Norm
**NanoGPT uses Pre-Norm** (LayerNorm before sub-layers):
```python
# Pre-norm (NanoGPT)
x = x + attn(ln(x))
x = x + mlp(ln(x))
# Post-norm (original Transformer)
x = ln(x + attn(x))
x = ln(x + mlp(x))
```
**Why Pre-Norm?**
- More stable training (no gradient explosion)
- Used in GPT-2, GPT-3
- Standard for large language models
### 2. Weight Tying
**Shared weights between embeddings and output**:
```python
self.transformer.wte.weight = self.lm_head.weight
```
**Why?**
- Reduces parameters: `vocab_size × n_embd` saved
- Improves training (same semantic space)
- Standard in GPT-2
### 3. Scaled Residual Initialization
```python
# Scale down residual projections by layer depth
std = 0.02 / math.sqrt(2 * n_layer)
torch.nn.init.normal_(c_proj.weight, mean=0.0, std=std)
```
**Why?**
- Prevents gradient explosion in deep networks
- Each residual path contributes ~equally
- From GPT-2 paper
### 4. Flash Attention
```python
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
# Use PyTorch 2.0 Flash Attention (2× faster!)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
# Fallback to manual attention
att = (q @ k.T) / sqrt(d)
att = masked_fill(att, causal_mask, -inf)
y = softmax(att) @ v
```
**Speedup**: 2× faster with same accuracy
## Model Sizes
| Model | n_layer | n_head | n_embd | Params | Config Name |
|-------|---------|--------|--------|--------|-------------|
| GPT-2 Small | 12 | 12 | 768 | 124M | `gpt2` |
| GPT-2 Medium | 24 | 16 | 1024 | 350M | `gpt2-medium` |
| GPT-2 Large | 36 | 20 | 1280 | 774M | `gpt2-large` |
| GPT-2 XL | 48 | 25 | 1600 | 1558M | `gpt2-xl` |
**NanoGPT default** (Shakespeare):
```python
config = GPTConfig(
block_size=256, # Short context for char-level
vocab_size=65, # Small vocab (a-z, A-Z, punctuation)
n_layer=6, # Shallow network
n_head=6,
n_embd=384, # Small embeddings
dropout=0.2 # Regularization
)
# Total: ~10M parameters
```
## Attention Visualization
```python
# What each token attends to (lower triangular)
# Token t can only attend to tokens 0...t
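# Attention pattern (causal mask): rows = query position, columns = key position
# ✓ = may attend, - = masked
#        t=0  t=1  t=2  t=3
# t=0     ✓    -    -    -
# t=1     ✓    ✓    -    -
# t=2     ✓    ✓    ✓    -
# t=3     ✓    ✓    ✓    ✓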
# Prevents "cheating" by looking at future tokens
```
## Residual Stream
**Information flow through residuals**:
```python
# Input
x = token_emb + pos_emb
# Block 1
x = x + attn_1(ln(x)) # Attention adds to residual
x = x + mlp_1(ln(x)) # MLP adds to residual
# Block 2
x = x + attn_2(ln(x))
x = x + mlp_2(ln(x))
# ... (repeat for all layers)
# Output
logits = lm_head(ln(x))
```
**Key insight**: Each layer refines the representation, residuals preserve gradients
## Tokenization
### Character-Level (Shakespeare)
```python
# data/shakespeare_char/prepare.py
text = open('input.txt', 'r').read()
chars = sorted(list(set(text))) # ['!', ',', '.', 'A', 'B', ..., 'z']
vocab_size = len(chars) # 65
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Encode
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
data = torch.tensor(encode(text), dtype=torch.long)
```
### BPE (GPT-2)
```python
# data/openwebtext/prepare.py
import tiktoken
enc = tiktoken.get_encoding("gpt2") # GPT-2 BPE tokenizer
vocab_size = enc.n_vocab # 50257
# Encode
tokens = enc.encode_ordinary("Hello world") # [15496, 995]
# Decode
text = enc.decode(tokens) # "Hello world"
```
## Resources
- **GitHub**: https://github.com/karpathy/nanoGPT ⭐ 48,000+
- **Video**: "Let's build GPT" by Andrej Karpathy
- **Paper**: "Attention is All You Need" (Vaswani et al.)
- **Paper**: "Language Models are Unsupervised Multitask Learners" (GPT-2)
- **Code walkthrough**: https://github.com/karpathy/nanoGPT/blob/master/model.py (the full model definition)
@@ -0,0 +1,476 @@
# NanoGPT Data Preparation
## Data Format
NanoGPT uses **binary token files** for efficient loading:
```
dataset/
├── train.bin # Training tokens (uint16 array)
├── val.bin # Validation tokens (uint16 array)
└── meta.pkl # Metadata (vocab_size, mappings)
```
**Why binary?**
- 100× faster than reading text files
- Memory-mapped loading (no RAM overhead)
- Simple format (just token IDs)
## Character-Level Tokenization
### Shakespeare Example
**Input text**:
```
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
```
**Character vocabulary** (65 total):
```python
chars = ['\n', ' ', '!', ',', '.', ':', ';', '?', 'A', 'B', ..., 'z']
stoi = {'\n': 0, ' ': 1, '!': 2, ...} # char → ID
itos = {0: '\n', 1: ' ', 2: '!', ...} # ID → char
```
**Tokenization**:
```python
text = "First Citizen:"
tokens = [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10]
# F=18, i=47, r=56, s=57, t=58, ' '=1, C=15, ..., z=64, e=43, n=52, ':'=10
```
**Full preparation script**:
```python
# data/shakespeare_char/prepare.py
import os
import requests
import pickle
import numpy as np
# Download Shakespeare dataset
input_file = 'input.txt'
if not os.path.exists(input_file):
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file, 'w') as f:
f.write(requests.get(url).text)
# Load text
with open(input_file, 'r') as f:
data = f.read()
print(f"Dataset size: {len(data):,} characters")
# Build vocabulary
chars = sorted(list(set(data)))
vocab_size = len(chars)
print(f"Vocabulary: {vocab_size} unique characters")
print(f"Characters: {''.join(chars[:20])}...")
# Create mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Encode full dataset
def encode(s):
return [stoi[c] for c in s]
def decode(l):
return ''.join([itos[i] for i in l])
# Split train/val (90/10)
n = len(data)
train_data = data[:int(n * 0.9)]
val_data = data[int(n * 0.9):]
# Tokenize
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"Train: {len(train_ids):,} tokens")
print(f"Val: {len(val_ids):,} tokens")
# Save as binary (uint16)
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile('train.bin')
val_ids.tofile('val.bin')
# Save metadata
meta = {
'vocab_size': vocab_size,
'itos': itos,
'stoi': stoi,
}
with open('meta.pkl', 'wb') as f:
pickle.dump(meta, f)
print("Saved train.bin, val.bin, meta.pkl")
```
**Output**:
```
Dataset size: 1,115,394 characters
Vocabulary: 65 unique characters
Characters: 
 !$&',-.3:;?ABCDEFG...
Train: 1,003,854 tokens
Val: 111,540 tokens
Saved train.bin, val.bin, meta.pkl
```
### Custom Character Dataset
```python
# For your own text dataset
text = open('my_data.txt', 'r').read()
# Build vocab
chars = sorted(list(set(text)))
vocab_size = len(chars)
# Create mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Encode
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
# Split and save
data = np.array(encode(text), dtype=np.uint16)
n = len(data)
train = data[:int(n*0.9)]
val = data[int(n*0.9):]
train.tofile('data/custom/train.bin')
val.tofile('data/custom/val.bin')
# Save meta
with open('data/custom/meta.pkl', 'wb') as f:
pickle.dump({'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}, f)
```
## BPE (Byte Pair Encoding)
### OpenWebText with GPT-2 Tokenizer
**BPE advantages**:
- Handles rare words better (subword units)
- Standard for GPT-2, GPT-3
- Vocabulary: 50,257 tokens
**Preparation script**:
```python
# data/openwebtext/prepare.py
# Downloads OpenWebText, tokenizes it with the GPT-2 BPE and writes
# train.bin / val.bin / meta.pkl next to this script.
import os
import pickle

import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm

# Number of workers for parallel processing
num_proc = 8
num_proc_load_dataset = num_proc

# Download OpenWebText dataset
dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)

# Hold out a small random set of documents for validation so that val.bin
# never contains text the model was trained on
split_dataset = dataset['train'].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test')  # rename the held-out split

# Use GPT-2 tokenizer
enc = tiktoken.get_encoding("gpt2")


def process(example):
    """Tokenize a single example and terminate it with the end-of-text token."""
    ids = enc.encode_ordinary(example['text'])  # Tokenize (ignores special tokens)
    ids.append(enc.eot_token)  # Add end-of-text token
    out = {'ids': ids, 'len': len(ids)}
    return out


# Tokenize both splits (parallel)
tokenized = split_dataset.map(
    process,
    remove_columns=['text'],
    desc="Tokenizing",
    num_proc=num_proc,
)


def concat_ids(split, desc):
    """Concatenate the token IDs of every document in `split` into one uint16 array.

    Iterates row by row (slicing a Dataset would return a dict of columns).
    NOTE: this holds the whole split in RAM.
    """
    return np.concatenate([
        np.array(sample['ids'], dtype=np.uint16)
        for sample in tqdm(split, desc=desc)
    ])


out_dir = os.path.dirname(__file__)

# Save train.bin
train_ids = concat_ids(tokenized['train'], "Concatenating train")
print(f"Total tokens: {len(train_ids):,}")  # ~9 billion tokens
train_ids.tofile(os.path.join(out_dir, 'train.bin'))

# Save val.bin (the held-out documents)
val_ids = concat_ids(tokenized['val'], "Concatenating val")
val_ids.tofile(os.path.join(out_dir, 'val.bin'))

# Save metadata
meta = {
    'vocab_size': enc.n_vocab,
    'eot_token': enc.eot_token,
}
with open(os.path.join(out_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

print(f"Train tokens: {len(train_ids):,}")
print(f"Val tokens: {len(val_ids):,}")
print(f"Vocab size: {enc.n_vocab:,}")
```
**Output**:
```
Total tokens: 9,035,582,198
Train tokens: 9,035,582,198
Val tokens: 4,123,676
Vocab size: 50,257
```
**Time**: 1-2 hours on 8-core CPU
**Disk usage**:
- train.bin: ~18 GB (9B tokens × 2 bytes)
- val.bin: ~8 MB
- Original text: ~54 GB
### BPE Tokenization Example
```python
import tiktoken
enc = tiktoken.get_encoding("gpt2")
# Tokenize
text = "Hello world! This is a test."
tokens = enc.encode_ordinary(text)
print(tokens)
# [15496, 995, 0, 770, 318, 257, 1332, 13]
# Decode
decoded = enc.decode(tokens)
print(decoded)
# "Hello world! This is a test."
# Token → text
print([enc.decode([t]) for t in tokens])
# ['Hello', ' world', '!', ' This', ' is', ' a', ' test', '.']
```
**Subword splitting**:
```python
# Rare word "electroencephalography" is split
tokens = enc.encode_ordinary("electroencephalography")
print([enc.decode([t]) for t in tokens])
# ['elect', 'ro', 'ence', 'ph', 'al', 'ography']
```
## Data Loading
### Memory-Mapped Loading (Efficient)
```python
import os

import numpy as np
import torch

# Load data (memory-mapped, no RAM overhead)
data_dir = 'data/shakespeare_char'
train_data = np.memmap(
    os.path.join(data_dir, 'train.bin'),
    dtype=np.uint16,
    mode='r'
)
val_data = np.memmap(
    os.path.join(data_dir, 'val.bin'),
    dtype=np.uint16,
    mode='r'
)
print(f"Loaded {len(train_data):,} tokens")  # No actual read yet!

# Batch geometry (the Shakespeare char-level values) and target device
batch_size = 64
block_size = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Get batch (read on-demand)
def get_batch(split):
    """Sample `batch_size` random windows of `block_size` tokens.

    Args:
        split: 'train' selects the training file; anything else selects validation.

    Returns:
        (x, y) int64 tensors of shape (batch_size, block_size) on `device`,
        where y is x shifted one token to the right.
    """
    data = train_data if split == 'train' else val_data
    # Random start offsets; the upper bound leaves room for the shifted target
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # uint16 is converted to int64 because embedding lookups need long indices
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    # Move to the training device
    x, y = x.to(device), y.to(device)
    return x, y


# Usage
X, Y = get_batch('train')
# X shape: (batch_size, block_size)
# Y shape: (batch_size, block_size)
```
**Memory efficiency**:
- 18 GB dataset (9B tokens × 2 bytes) loaded with ~0 MB RAM
- Only batch data is loaded into memory
### Data Loader (PyTorch)
```python
from torch.utils.data import Dataset, DataLoader
class TokenDataset(Dataset):
    """Map-style dataset over a flat file of uint16 token IDs.

    Item ``i`` is the pair ``(tokens[i:i+block_size], tokens[i+1:i+1+block_size])``,
    i.e. an input window and the same window shifted by one token.
    """

    def __init__(self, data_path, block_size):
        # Memory-mapped: tokens are read from disk only when sliced
        self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
        self.block_size = block_size

    def __len__(self):
        # A file with block_size tokens or fewer has no complete (x, y) pair;
        # report 0 instead of a negative length, which len() would reject
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        # Reject out-of-range indices instead of silently returning short windows
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for {len(self)} samples")
        stop = idx + self.block_size
        # int64 because embedding lookups need long indices
        x = torch.from_numpy(self.data[idx:stop].astype(np.int64))
        y = torch.from_numpy(self.data[idx+1:stop+1].astype(np.int64))
        return x, y
# Create data loader
train_dataset = TokenDataset('data/shakespeare_char/train.bin', block_size=256)
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True
)
# Usage
for X, Y in train_loader:
X, Y = X.to('cuda'), Y.to('cuda')
# Train...
```
## Custom Datasets
### Wikipedia
```python
from datasets import load_dataset
# Load Wikipedia
dataset = load_dataset("wikipedia", "20220301.en", num_proc=8)
# Tokenize
enc = tiktoken.get_encoding("gpt2")
def tokenize(example):
ids = enc.encode_ordinary(example['text'])
return {'ids': ids, 'len': len(ids)}
tokenized = dataset.map(tokenize, num_proc=8, remove_columns=['text', 'title'])
# Save
train_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['train']])
train_ids.tofile('data/wikipedia/train.bin')
```
### Code (GitHub)
```python
from datasets import load_dataset
# Load code dataset (The Stack)
dataset = load_dataset("bigcode/the-stack", data_dir="data/python", num_proc=8)
# Tokenize (same as above)
enc = tiktoken.get_encoding("gpt2")
# ... tokenize and save
```
### Custom Text Files
```python
# Load custom text files
import glob
files = glob.glob('my_dataset/*.txt')
text = ''
for file in files:
with open(file, 'r') as f:
text += f.read() + '\n'
# Character-level
chars = sorted(list(set(text)))
stoi = {ch: i for i, ch in enumerate(chars)}
data = np.array([stoi[c] for c in text], dtype=np.uint16)
# Split and save
n = len(data)
train = data[:int(n*0.9)]
val = data[int(n*0.9):]
train.tofile('data/custom/train.bin')
val.tofile('data/custom/val.bin')
# Meta
with open('data/custom/meta.pkl', 'wb') as f:
pickle.dump({'vocab_size': len(chars), 'itos': {i: ch for i, ch in enumerate(chars)}, 'stoi': stoi}, f)
```
## Data Augmentation (Advanced)
### Random Masking (BERT-style)
```python
def random_mask(tokens, mask_prob=0.15):
"""Randomly mask tokens for denoising objective."""
mask = torch.rand(tokens.shape) < mask_prob
tokens[mask] = mask_token_id
return tokens
# Usage in training
X, Y = get_batch('train')
X_masked = random_mask(X.clone())
logits, loss = model(X_masked, Y) # Predict original from masked
```
### Document Shuffling
```python
# Shuffle document order (not token order)
# Better generalization than sequential documents
# `tokenized` is the result of dataset.map(...) in the preparation script:
# only the tokenized rows carry the 'ids' column used below.
# Dataset.shuffle() returns a new, shuffled dataset; random.shuffle() cannot
# be used because a HuggingFace Dataset does not support item assignment.
docs = tokenized['train'].shuffle(seed=42)
# Concatenate shuffled
train_ids = np.concatenate([np.array(doc['ids'], dtype=np.uint16) for doc in docs])
```
## Benchmarks
| Dataset | Tokens | Vocab | Prep Time | Disk Size |
|---------|--------|-------|-----------|-----------|
| Shakespeare (char) | 1M | 65 | 1 sec | 2 MB |
| TinyStories | 250M | 50K | 5 min | 500 MB |
| OpenWebText | 9B | 50K | 90 min | 18 GB |
| The Pile | 300B | 50K | ~2 days | 600 GB |
## Resources
- Data preparation scripts: https://github.com/karpathy/nanoGPT/tree/master/data
- Tiktoken (BPE tokenizer): https://github.com/openai/tiktoken
- HuggingFace datasets: https://huggingface.co/datasets
- OpenWebText: https://huggingface.co/datasets/Skylion007/openwebtext
- The Stack (code): https://huggingface.co/datasets/bigcode/the-stack
@@ -0,0 +1,564 @@
# NanoGPT Training Guide
## Training Loop (~300 Lines)
NanoGPT's `train.py` is a self-contained training script with minimal dependencies.
### Complete Training Script Structure
```python
# train.py (simplified)
# Single-GPU training loop: cosine LR schedule, gradient accumulation,
# mixed precision and gradient clipping.
import os
import time
import math
import pickle

import numpy as np
import torch
from model import GPTConfig, GPT

# Training config
batch_size = 12 # Micro batch size
block_size = 1024 # Context length
gradient_accumulation_steps = 5 * 8 # 40 micro-steps × 12 × 1024 ≈ 0.5M tokens per optimizer step
# Model config
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0
# Optimizer config
learning_rate = 6e-4
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
# Learning rate schedule
warmup_iters = 2000
lr_decay_iters = 600000
min_lr = 6e-5
# System
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
compile = True # PyTorch 2.0

# Data: token files written by data/openwebtext/prepare.py (memory-mapped)
data_dir = os.path.join('data', 'openwebtext')
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')


# Data loader
def get_batch(split):
    """Return (x, y) int64 tensors of shape (batch_size, block_size) on `device`; y is x shifted by one."""
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # torch.stack needs tensors, and embedding lookups need int64 indices
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


# Learning rate schedule
def get_lr(it):
    """Linear warmup, then cosine decay from learning_rate to min_lr, then constant min_lr."""
    # Warmup
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Decay to min_lr
    if it > lr_decay_iters:
        return min_lr
    # Cosine decay
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)


# Init model from the config values above (not the GPTConfig defaults)
model = GPT(GPTConfig(
    block_size=block_size,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    dropout=dropout,
))
model.to(device)
# Compile model (PyTorch 2.0)
if compile:
    print("Compiling model...")
    model = torch.compile(model)
# Optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device)

# Mixed precision: autocast to the dtype selected above; the gradient scaler
# is only active for float16 and is a no-op for bfloat16
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Training loop
for iter_num in range(max_iters):
    # Set learning rate
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Gradient accumulation
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch('train')
        with torch.amp.autocast(device_type='cuda', dtype=ptdtype):
            logits, loss = model(X, Y)
            # Scale so the accumulated gradient is the mean over micro-steps
            loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()
    # Clip gradients (unscale first so the threshold applies to the true gradients)
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # Update weights
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    # Logging
    if iter_num % 100 == 0:
        # Undo the accumulation scaling so the printed value is the real loss
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {iter_num}: loss {lossf:.4f}, lr {lr:.2e}")
```
## Data Preparation
### Shakespeare Character-Level
```bash
# Step 1: Download Shakespeare
cd data/shakespeare_char
python prepare.py
# Creates:
# - train.bin (90% of data, ~2MB: ~1M tokens × 2 bytes)
# - val.bin (10% of data, ~220KB)
# - meta.pkl (vocab info)
```
**prepare.py**:
```python
import os
import pickle
import requests
import numpy as np
# Download
input_file = 'input.txt'
if not os.path.exists(input_file):
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file, 'w') as f:
f.write(requests.get(url).text)
# Read and process
with open(input_file, 'r') as f:
data = f.read()
print(f"Length: {len(data):,} characters")
# Create vocabulary
chars = sorted(list(set(data)))
vocab_size = len(chars)
print(f"Vocab size: {vocab_size}")
# Create mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Encode dataset
data_ids = [stoi[c] for c in data]
# Train/val split
n = len(data_ids)
train_ids = data_ids[:int(n*0.9)]
val_ids = data_ids[int(n*0.9):]
# Save as numpy arrays
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile('train.bin')
val_ids.tofile('val.bin')
# Save metadata
meta = {'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}
with open('meta.pkl', 'wb') as f:
pickle.dump(meta, f)
```
### OpenWebText (GPT-2 Reproduction)
```bash
# Step 1: Download OpenWebText (~12GB compressed)
cd data/openwebtext
python prepare.py
# Warning: Takes 1-2 hours; the original text needs ~54GB of disk and train.bin adds ~18GB
```
**prepare.py**:
```python
import os
import pickle

import numpy as np
import tiktoken
from datasets import load_dataset

# Download dataset
dataset = load_dataset("openwebtext", num_proc=8)

# Hold out a small random set of documents for validation so that val.bin
# never contains text the model was trained on
split_dataset = dataset['train'].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test')  # rename the held-out split

# Use GPT-2 tokenizer
enc = tiktoken.get_encoding("gpt2")


def tokenize(example):
    """Tokenize one document and terminate it with the end-of-text token."""
    ids = enc.encode_ordinary(example['text'])
    ids.append(enc.eot_token)  # Add <|endoftext|>
    return {'ids': ids, 'len': len(ids)}


# Tokenize both splits (parallel)
tokenized = split_dataset.map(
    tokenize,
    remove_columns=['text'],
    desc="Tokenizing",
    num_proc=8
)

# Concatenate all training tokens (iterating row by row)
train_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['train']])
print(f"Train tokens: {len(train_ids):,}")  # ~9B tokens
# Save
train_ids.tofile('train.bin')

# Validation set (the held-out documents)
val_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['val']])
val_ids.tofile('val.bin')

# Save metadata
meta = {'vocab_size': enc.n_vocab, 'eot_token': enc.eot_token}
with open('meta.pkl', 'wb') as f:
    pickle.dump(meta, f)
```
## Learning Rate Schedules
### Cosine Decay with Warmup (GPT-2 style)
```python
# Example values
learning_rate = 6e-4 # Peak LR
min_lr = 6e-5 # Final LR (10% of peak)
warmup_iters = 2000 # Warmup steps
lr_decay_iters = 600000 # Total training steps


def get_lr(it):
    """Learning rate for iteration `it`: linear warmup, cosine decay, then a constant floor."""
    # Phase 1: ramp linearly from 0 up to the peak
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Phase 3: past the decay horizon the rate stays at the floor
    if it > lr_decay_iters:
        return min_lr
    # Phase 2: half-cosine from the peak down to the floor
    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (learning_rate - min_lr)
```
**Visualization**:
```
LR
^
| Peak (6e-4)
| /‾‾‾‾‾‾‾‾‾‾\
| / \
| / \_____ Min (6e-5)
| /
|/________________> Iteration
Warmup Cosine Const
(2K) (598K)
```
### Constant LR with Warmup (Simple)
```python
def get_lr(it):
if it < warmup_iters:
return learning_rate * it / warmup_iters
return learning_rate
# Good for small experiments
```
## Gradient Accumulation
**Effective batch size** = `batch_size × gradient_accumulation_steps × num_gpus`
```python
# Config
batch_size = 12 # Per-GPU micro batch
gradient_accumulation_steps = 40 # Accumulate gradients
# Effective batch: 12 × 40 = 480 sequences = ~0.5M tokens
# Training loop
optimizer.zero_grad()
for micro_step in range(gradient_accumulation_steps):
X, Y = get_batch('train')
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # Scale loss
loss.backward() # Accumulate gradients
# Update once
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
```
**Why?**
- Simulates large batch size without OOM
- GPT-2 (124M) uses effective batch ~0.5M tokens
- More stable training
## Mixed Precision Training
### BF16 (Best for A100/H100)
```python
# Enable bfloat16
dtype = torch.bfloat16
# Training loop
for iter_num in range(max_iters):
    X, Y = get_batch('train')
    # Forward in BF16
    with torch.amp.autocast(device_type='cuda', dtype=dtype):
        logits, loss = model(X, Y)
    # Backward (outside the autocast context; no gradient scaler is used)
    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next iteration
    optimizer.zero_grad(set_to_none=True)
```
**Advantages**:
- No gradient scaler needed
- Same dynamic range as FP32
- 2× faster, 50% memory reduction
### FP16 (V100, older GPUs)
```python
from torch.cuda.amp import GradScaler, autocast

# The scaler multiplies the loss before backward and rescales gradients before the step
scaler = GradScaler()
for iter_num in range(max_iters):
    X, Y = get_batch('train')
    # Forward in FP16
    with autocast():
        logits, loss = model(X, Y)
    # Scale loss, backward
    scaler.scale(loss).backward()
    # Unscale, clip gradients (clipping must see the unscaled gradients)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # Update weights
    scaler.step(optimizer)
    scaler.update()
    # Clear gradients so they do not accumulate into the next iteration
    optimizer.zero_grad(set_to_none=True)
```
## Distributed Data Parallel (DDP)
### Single Node, Multiple GPUs
```python
# train.py (DDP version)
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize (RANK, LOCAL_RANK and WORLD_SIZE are read from the environment
# that the torchrun launcher provides to each process)
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
# Model
model = GPT(GPTConfig())
model.to(device)
model = DDP(model, device_ids=[ddp_local_rank])
# Training loop (same as before, DDP handles gradient sync)
# NOTE: `optimizer` must be created from model.parameters() before this loop
for iter_num in range(max_iters):
    X, Y = get_batch('train')  # Each rank should sample different data (e.g. seed the RNG per rank)
    logits, loss = model(X, Y)
    loss.backward()  # DDP syncs gradients across GPUs
    optimizer.step()
    # Clear gradients so they do not accumulate into the next iteration
    optimizer.zero_grad(set_to_none=True)
# Shut the process group down cleanly once training is finished
dist.destroy_process_group()
```
**Launch**:
```bash
# 8 GPUs on single node
torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
```
### Multi-Node Training
```bash
# Node 0 (master)
torchrun --nproc_per_node=8 \
--nnodes=4 --node_rank=0 \
--master_addr=192.168.1.100 --master_port=29500 \
train.py config/train_gpt2.py
# Node 1-3 (workers)
torchrun --nproc_per_node=8 \
--nnodes=4 --node_rank=$RANK \
--master_addr=192.168.1.100 --master_port=29500 \
train.py config/train_gpt2.py
```
## Checkpointing
### Save Checkpoint
```python
# Save every N iterations
if iter_num % 5000 == 0:
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': model_args,
        'iter_num': iter_num,
        'best_val_loss': best_val_loss,
        'config': config,
    }
    # Numbered snapshot, kept for history / rollback
    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
    # Rolling copy under a fixed name: this is the file the resume path
    # ('ckpt_latest.pt') loads, so it must be refreshed on every save
    torch.save(checkpoint, os.path.join(out_dir, 'ckpt_latest.pt'))
```
### Resume from Checkpoint
```python
# Load checkpoint
init_from = 'resume' # or 'gpt2', 'gpt2-medium', etc.
if init_from == 'resume':
ckpt_path = os.path.join(out_dir, 'ckpt_latest.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
# Restore model
model_args = checkpoint['model_args']
model = GPT(GPTConfig(**model_args))
model.load_state_dict(checkpoint['model'])
# Restore optimizer
optimizer.load_state_dict(checkpoint['optimizer'])
# Restore iteration counter
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']
```
## Fine-Tuning Pretrained Models
### Load OpenAI GPT-2 Weights
```python
# model.py - from_pretrained method
@classmethod
def from_pretrained(cls, model_type):
"""Load pretrained GPT-2 model weights from HuggingFace."""
from transformers import GPT2LMHeadModel
# Download from HuggingFace
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
sd_hf = model_hf.state_dict()
# Filter out keys we don't need
sd_hf_keys = [k for k in sd_hf.keys() if not k.endswith('.attn.masked_bias')]
sd_hf_keys = [k for k in sd_hf_keys if not k.endswith('.attn.bias')]
# Create our model
config = GPTConfig.from_model_type(model_type)
model = GPT(config)
sd = model.state_dict()
# Copy weights (transpose Conv1D → Linear)
for k in sd_hf_keys:
if any([k.endswith(w) for w in ['.c_attn.weight', '.c_proj.weight', '.c_fc.weight']]):
sd[k] = sd_hf[k].t() # Transpose
else:
sd[k] = sd_hf[k] # Direct copy
model.load_state_dict(sd)
return model
# Usage
model = GPT.from_pretrained('gpt2') # Load GPT-2 (124M)
```
### Fine-Tune on Custom Data
```python
# config/finetune_shakespeare.py
init_from = 'gpt2'  # Start from GPT-2
# Use the BPE-tokenized Shakespeare dataset so token ids match GPT-2's vocabulary.
# ('shakespeare_char' is character-level and only suits from-scratch training.)
dataset = 'shakespeare'
# Fine-tuning hyperparameters
learning_rate = 3e-5  # Lower LR for fine-tuning
max_iters = 2000  # Short fine-tuning
warmup_iters = 100
# Regularization
weight_decay = 1e-1
dropout = 0.2  # Add dropout
# Run
# python train.py config/finetune_shakespeare.py
```
## Evaluation
### Perplexity
```python
@torch.no_grad()
def estimate_loss():
model.eval()
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch('val')
logits, loss = model(X, Y)
losses[k] = loss.item()
model.train()
return losses.mean()
# Usage
val_loss = estimate_loss()
perplexity = math.exp(val_loss)
print(f"Val perplexity: {perplexity:.2f}")
```
### Sample Generation
```python
# sample.py
model.eval()
start = "ROMEO:" # Prompt
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
# Generate
with torch.no_grad():
y = model.generate(x, max_new_tokens=500, temperature=0.8, top_k=200)
print(decode(y[0].tolist()))
```
## Training Times
| Setup | Model | Hardware | Batch Size | Time to Perplexity 10 |
|-------|-------|----------|------------|----------------------|
| Shakespeare | 10M | 1× CPU | 64 | 5 minutes |
| Shakespeare | 10M | 1× T4 GPU | 64 | 1 minute |
| OpenWebText | 124M | 1× A100 | 480 | 7 days |
| OpenWebText | 124M | 8× A100 | 3840 | 4 days |
| OpenWebText | 350M | 8× A100 | 1920 | 14 days |
## Resources
- Training script: https://github.com/karpathy/nanoGPT/blob/master/train.py
- Configs: https://github.com/karpathy/nanoGPT/tree/master/config
- Video walkthrough: "Let's build GPT" (training section)
- GPT-2 paper: https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
@@ -0,0 +1,260 @@
---
name: rwkv-architecture
description: RNN+Transformer hybrid with O(n) inference. Linear time, infinite context, no KV cache. Train like GPT (parallel), infer like RNN (sequential). Linux Foundation AI project. Production at Windows, Office, NeMo. RWKV-7 (March 2025). Models up to 14B parameters.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [RWKV, Model Architecture, RNN, Transformer Hybrid, Linear Complexity, Infinite Context, Efficient Inference, Linux Foundation, Alternative Architecture]
dependencies: [rwkv, torch, transformers]
---
# RWKV - Receptance Weighted Key Value
## Quick start
RWKV (RwaKuv) combines Transformer parallelization (training) with RNN efficiency (inference).
**Installation**:
```bash
# Install PyTorch
pip install torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu121
# Install dependencies
pip install pytorch-lightning==1.9.5 deepspeed wandb ninja --upgrade
# Install RWKV
pip install rwkv
```
**Basic usage** (GPT mode + RNN mode):
```python
import os
from rwkv.model import RWKV
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # Use CUDA kernel for speed
# Load model
model = RWKV(
model='/path/to/RWKV-4-Pile-1B5-20220903-8040',
strategy='cuda fp16'
)
# GPT mode (parallel processing)
out, state = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy()) # Logits
# RNN mode (sequential processing, same result)
out, state = model.forward([187, 510], None) # First 2 tokens
out, state = model.forward([1563], state) # Next token
out, state = model.forward([310, 247], state) # Last tokens
print(out.detach().cpu().numpy()) # Same logits as above!
```
## Common workflows
### Workflow 1: Text generation (streaming)
**Efficient token-by-token generation**:
```python
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
model = RWKV(model='RWKV-4-Pile-14B-20230313-ctx8192-test1050', strategy='cuda fp16')
pipeline = PIPELINE(model, "20B_tokenizer.json")
# Initial prompt
prompt = "The future of AI is"
# Tokenize the whole prompt once (iterating over the string would yield single
# characters, not tokens) and prime the recurrent state with it
prompt_tokens = pipeline.encode(prompt)
out, state = model.forward(prompt_tokens, None)
# Continue generation: sample from the latest logits, emit the token,
# then feed that token back in together with the carried state
for _ in range(100):
    token = pipeline.sample_logits(out, temperature=1.0, top_p=0.85)
    print(pipeline.decode([token]), end='', flush=True)
    out, state = model.forward([token], state)
```
**Key advantage**: Constant memory per token (no growing KV cache)
### Workflow 2: Long context processing (infinite context)
**Process million-token sequences**:
```python
model = RWKV(model='RWKV-4-Pile-14B', strategy='cuda fp16')
# Process very long document
state = None
long_document = load_document() # e.g., 1M tokens
# Stream through entire document
for chunk in chunks(long_document, chunk_size=1024):
out, state = model.forward(chunk, state)
# State now contains information from entire 1M token document
# Memory usage: O(1) (constant, not O(n)!)
```
### Workflow 3: Fine-tuning RWKV
**Standard fine-tuning workflow**:
```python
# Training script
import pytorch_lightning as pl
from rwkv.model import RWKV
from rwkv.trainer import RWKVTrainer
# Configure model
config = {
'n_layer': 24,
'n_embd': 1024,
'vocab_size': 50277,
'ctx_len': 1024
}
# Setup trainer
trainer = pl.Trainer(
accelerator='gpu',
devices=8,
precision='bf16',
strategy='deepspeed_stage_2',
max_epochs=1
)
# Train
model = RWKV(config)
trainer.fit(model, train_dataloader)
```
### Workflow 4: RWKV vs Transformer comparison
**Memory comparison** (1M token sequence):
```python
# Transformer (GPT)
# Memory: O(n²) for attention
# KV cache: 1M × hidden_dim × n_layers × 2 (keys + values)
# Example: 1M × 4096 × 24 × 2 = ~400GB (impractical!)
# RWKV
# Memory: O(1) per token
# State: hidden_dim × n_layers = 4096 × 24 = ~400KB
# 1,000,000× more efficient!
```
**Speed comparison** (inference):
```python
# Transformer: O(n) per token (quadratic overall)
# First token: 1 computation
# Second token: 2 computations
# ...
# 1000th token: 1000 computations
# RWKV: O(1) per token (linear overall)
# Every token: 1 computation
# 1000th token: 1 computation (same as first!)
```
## When to use vs alternatives
**Use RWKV when**:
- Need very long context (100K+ tokens)
- Want constant memory usage
- Building streaming applications
- Need RNN efficiency with Transformer performance
- Memory-constrained deployment
**Key advantages**:
- **Linear time**: O(n) vs O(n²) for Transformers
- **No KV cache**: Constant memory per token
- **Infinite context**: No fixed window limit
- **Parallelizable training**: Like GPT
- **Sequential inference**: Like RNN
**Use alternatives instead**:
- **Transformers**: Need absolute best performance, have compute
- **Mamba**: Want state-space models
- **RetNet**: Need retention mechanism
- **Hyena**: Want convolution-based approach
## Common issues
**Issue: Out of memory during training**
Use gradient checkpointing and DeepSpeed:
```python
trainer = pl.Trainer(
strategy='deepspeed_stage_3', # Full ZeRO-3
precision='bf16'
)
```
**Issue: Slow inference**
Enable CUDA kernel:
```python
os.environ["RWKV_CUDA_ON"] = '1'
```
**Issue: Model not loading**
Check model path and strategy:
```python
model = RWKV(
model='/absolute/path/to/model.pth',
strategy='cuda fp16' # Or 'cpu fp32' for CPU
)
```
**Issue: State management in RNN mode**
Always pass state between forward calls:
```python
# WRONG: State lost
out1, _ = model.forward(tokens1, None)
out2, _ = model.forward(tokens2, None) # No context from tokens1!
# CORRECT: State preserved
out1, state = model.forward(tokens1, None)
out2, state = model.forward(tokens2, state) # Has context from tokens1
```
## Advanced topics
**Time-mixing and channel-mixing**: See [references/architecture-details.md](references/architecture-details.md) for WKV operation, time-decay mechanism, and receptance gates.
**State management**: See [references/state-management.md](references/state-management.md) for att_x_prev, att_kv, ffn_x_prev states, and numerical stability considerations.
**RWKV-7 improvements**: See [references/rwkv7.md](references/rwkv7.md) for latest architectural improvements (March 2025) and multimodal capabilities.
## Hardware requirements
- **GPU**: NVIDIA (CUDA 11.6+) or CPU
- **VRAM** (FP16):
- 169M model: 1GB
- 430M model: 2GB
- 1.5B model: 4GB
- 3B model: 8GB
- 7B model: 16GB
- 14B model: 32GB
- **Inference**: O(1) memory per token
- **Training**: Parallelizable like GPT
**Performance** (vs Transformers):
- **Speed**: Similar training, faster inference
- **Memory**: 1000× less for long sequences
- **Scaling**: Linear vs quadratic
## Resources
- Paper (RWKV): https://arxiv.org/abs/2305.13048 (May 2023)
- Paper (RWKV-7): https://arxiv.org/abs/2503.14456 (March 2025)
- GitHub: https://github.com/BlinkDL/RWKV-LM ⭐ 12,000+
- Docs: https://wiki.rwkv.com/
- Models: https://huggingface.co/BlinkDL
- Linux Foundation AI: Official project
- Production: Microsoft Windows, Office integration, NeMo support
@@ -0,0 +1,344 @@
# RWKV Architecture Details
## Time-Mixing and Channel-Mixing Blocks
RWKV alternates between **Time-Mixing** (sequence processing) and **Channel-Mixing** (feature processing) blocks.
### Time-Mixing Block (WKV Operation)
The core innovation is the **WKV (Weighted Key-Value)** mechanism:
```python
# Traditional Attention (O(n²))
scores = Q @ K.T / sqrt(d)  # n×n matrix
attention = softmax(scores)
output = attention @ V
# RWKV Time-Mixing (O(n))
# Compute WKV in linear time using a per-channel recurrence (RWKV-4).
# All products are element-wise over channels.
#   w >= 0 : per-channel decay rate     u : bonus applied to the current token
#   aa, ab : running numerator / denominator over tokens 0..t-1 (both start at 0)
for t in range(T):
    cur = exp(u + k[t])                      # weight of the current token
    wkv[t] = (aa + cur * v[t]) / (ab + cur)  # weighted average of past + current values
    aa = exp(-w) * aa + exp(k[t]) * v[t]     # decay history, then add token t
    ab = exp(-w) * ab + exp(k[t])
```
**Full Time-Mixing implementation**:
```python
class RWKV_TimeMix(nn.Module):
    """RWKV-4 style time-mixing block (reference implementation, not log-space stabilised).

    Args:
        d_model: channel width C.
        n_layer: accepted for interface compatibility; not used by this block.

    The recurrent state is a (B, C, 3) tensor holding, along the last axis,
    [aa, ab, x_prev]: the WKV numerator accumulator, the WKV denominator
    accumulator, and the last input vector of the previous call.
    """

    def __init__(self, d_model, n_layer):
        super().__init__()
        self.d_model = d_model
        # Linear projections
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Time-mixing parameters
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        # Time-decay and bonus
        self.time_decay = nn.Parameter(torch.ones(d_model))  # w
        self.time_first = nn.Parameter(torch.ones(d_model))  # u

    def forward(self, x, state=None):
        """Run time-mixing over x of shape (B, T, C).

        Returns:
            (out, new_state): out is (B, T, C); new_state is (B, C, 3) and can be
            passed to the next call to continue the same sequence.
        """
        B, T, C = x.shape
        if state is None:
            state = torch.zeros(B, C, 3, device=x.device, dtype=x.dtype)  # [aa, ab, x_prev]
        aa, ab, last_x = state.unbind(dim=2)
        # Time-shift: position t is mixed with position t-1; position 0 uses the
        # vector carried over from the previous call
        x_prev = torch.cat([last_x.unsqueeze(1), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + x_prev * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        # Compute k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        # WKV computation; also yields the accumulators after the last token
        wkv, aa, ab = self._wkv_scan(k, v, aa, ab)
        # Apply receptance gate and output projection
        out = self.output(torch.sigmoid(r) * wkv)
        # Update state
        new_state = torch.stack([aa, ab, x[:, -1]], dim=2)
        return out, new_state

    def wkv(self, k, v, state):
        """Return only the WKV output for k, v given state (B, C, 2) = [aa, ab]."""
        out, _, _ = self._wkv_scan(k, v, state[:, :, 0], state[:, :, 1])
        return out

    def _wkv_scan(self, k, v, aa, ab):
        """Sequential WKV scan over the time axis; returns (wkv, aa, ab)."""
        # exp(-exp(w)) is the per-step retention factor, always in (0, 1)
        decay = torch.exp(-torch.exp(self.time_decay))
        outputs = []
        for t in range(k.shape[1]):
            k_t, v_t = k[:, t], v[:, t]
            cur = torch.exp(self.time_first + k_t)  # bonus-weighted current token
            outputs.append((aa + cur * v_t) / (ab + cur))
            e_k = torch.exp(k_t)
            aa = decay * aa + e_k * v_t
            ab = decay * ab + e_k
        return torch.stack(outputs, dim=1), aa, ab
```
### WKV Parallel Algorithm (Training)
```python
def wkv_forward(w, u, k, v):
    """
    Reference WKV computation over a full sequence (RWKV-4 formulation).

    w: time_decay (d_model,)  -- raw parameter; retention factor is exp(-exp(w))
    u: time_first (d_model,)  -- bonus added to the current token's key
    k: keys (batch, seq_len, d_model)
    v: values (batch, seq_len, d_model)

    Returns wkv with shape (batch, seq_len, d_model), where
        wkv[t] = (sum_{i<t} decay^(t-1-i) e^{k_i} v_i + e^{u+k_t} v_t)
               / (sum_{i<t} decay^(t-1-i) e^{k_i}     + e^{u+k_t})

    This loop is O(T) and vectorised over batch and channels only; it is not
    log-space stabilised, so very large keys can overflow exp().
    """
    B, T, C = k.shape
    decay = torch.exp(-torch.exp(w))  # in (0, 1): fraction of history kept per step
    # Running numerator / denominator over tokens strictly before t
    aa = torch.zeros(B, C, device=k.device, dtype=k.dtype)
    ab = torch.zeros(B, C, device=k.device, dtype=k.dtype)
    outputs = []
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]
        cur = torch.exp(u + k_t)
        outputs.append((aa + cur * v_t) / (ab + cur))
        e_k = torch.exp(k_t)
        aa = decay * aa + e_k * v_t
        ab = decay * ab + e_k
    if not outputs:
        return torch.zeros(B, T, C, device=k.device, dtype=k.dtype)
    return torch.stack(outputs, dim=1)
```
### WKV Sequential Algorithm (Inference)
```python
def wkv_inference(w, u, k, v, state):
    """
    Sequential WKV for O(1) per-token inference (RWKV-4 formulation).

    w: time_decay (raw parameter; retention factor is exp(-exp(w)))
    u: time_first (bonus added to the current token's key)
    k, v: key and value of the current token
    state: (aa, ab) from previous step; use zeros for the first token

    Returns (wkv, (new_aa, new_ab)).
    """
    # The state must be scaled by a factor in (0, 1); multiplying by the
    # negative number -exp(w) would flip its sign on every step
    decay = torch.exp(-torch.exp(w))
    # Unpack state
    aa, ab = state  # aa = numerator, ab = denominator
    # Compute WKV for current token
    cur = torch.exp(u + k)
    wkv = (aa + cur * v) / (ab + cur)
    # Update state for next token
    e_k = torch.exp(k)
    new_aa = decay * aa + e_k * v
    new_ab = decay * ab + e_k
    return wkv, (new_aa, new_ab)
```
### Channel-Mixing Block
Replaces Transformer FFN with time-shifted variant:
```python
class RWKV_ChannelMix(nn.Module):
    """Channel-mixing block: a time-shifted, receptance-gated FFN with squared-ReLU.

    Args:
        d_model: channel width.
        hidden_ratio: hidden width as a multiple of d_model. May be fractional
            (e.g. 3.5); the resulting width is truncated to an int because
            nn.Linear requires integer feature sizes.
    """

    def __init__(self, d_model, hidden_ratio=4):
        super().__init__()
        self.d_model = d_model
        self.hidden = int(d_model * hidden_ratio)
        # Time-mixing for channel
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        # FFN layers
        self.key = nn.Linear(d_model, self.hidden, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(self.hidden, d_model, bias=False)

    def forward(self, x, x_prev):
        """Mix x with x_prev (must broadcast against x), then apply the gated FFN."""
        # Time-shift mixing
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        # Channel mixing
        k = self.key(xk)
        k = torch.square(torch.relu(k))  # Squared ReLU activation
        kv = self.value(k)
        # Receptance gate
        r = torch.sigmoid(self.receptance(xr))
        return r * kv
```
## RWKV Block Structure
```python
class RWKV_Block(nn.Module):
def __init__(self, d_model, n_layer):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.att = RWKV_TimeMix(d_model, n_layer)
self.ffn = RWKV_ChannelMix(d_model)
def forward(self, x, state):
# Time-mixing with residual
att_out, new_state = self.att(self.ln1(x), state)
x = x + att_out
# Channel-mixing with residual
ffn_out = self.ffn(self.ln2(x), state[:, :, 2]) # Use x_prev from state
x = x + ffn_out
return x, new_state
# Full RWKV model
model = nn.Sequential(
Embedding(...),
*[RWKV_Block(d_model, i) for i in range(n_layers)],
LayerNorm(d_model),
LMHead(...)
)
```
## Time-Decay Mechanism
The **time_decay** parameter `w` controls how fast information decays:
```python
# Initialization (RWKV-4): decay is spread across channels, and the spread is
# skewed by depth so deeper layers keep more slow-decaying channels
time_decay = torch.ones(n_layers, d_model)
for i in range(n_layers):
    ratio_0_to_1 = i / max(n_layers - 1, 1)  # 0 at the first layer, 1 at the last
    for j in range(d_model):
        pos = j / max(d_model - 1, 1)  # 0..1 across channels
        time_decay[i, j] = -5.0 + 8.0 * pos ** (0.7 + 1.3 * ratio_0_to_1)
# Effect on memory
w = -exp(time_decay)  # Range: [-exp(3), -exp(-5)] ≈ [-20, -0.007]
# exp(w) is the fraction of the state kept per token:
# |w| small (w ≈ -0.007): exp(w) ≈ 0.993 = slower decay = longer memory
# |w| large (w ≈ -20):    exp(w) ≈ 2e-9  = faster decay = shorter memory
```
**Layer-wise decay pattern**:
- Early layers (shallow): Fast decay, capture local patterns
- Later layers (deep): Slow decay, capture long-range dependencies
## Receptance Gate
The **receptance** mechanism controls information flow:
```python
r = sigmoid(receptance(x)) # Range [0, 1]
output = r * wkv # Gate the WKV output
# High receptance (r ≈ 1): Pass information through
# Low receptance (r ≈ 0): Block information
```
**Purpose**: Similar to an LSTM output gate (it scales what the block emits, not what the state retains), computed per token and per channel
## RWKV-4 vs RWKV-5 vs RWKV-6 vs RWKV-7
### RWKV-4 (Original)
```python
# Time-shift with previous token
xx = x * time_mix + x_prev * (1 - time_mix)
k, v, r = key(xx), value(xx), receptance(xx)
```
### RWKV-5 (2023)
```python
# Separate time-mix for k, v, r
xk = x * time_mix_k + x_prev * (1 - time_mix_k)
xv = x * time_mix_v + x_prev * (1 - time_mix_v)
xr = x * time_mix_r + x_prev * (1 - time_mix_r)
# Each projection consumes its own mixed input (value uses xv, not xk)
k, v, r = key(xk), value(xv), receptance(xr)
```
### RWKV-6 (2024)
- Added **multi-head time-mixing** (like multi-head attention)
- Separate time-decay per head
- Improved stability for large models
```python
# Per-head processing
for h in range(n_heads):
k_h = key[h](x) # Separate projection per head
w_h = time_decay[h] # Separate decay per head
wkv_h = wkv(k_h, v_h, w_h)
output = concat(wkv_0, wkv_1, ..., wkv_H)
```
### RWKV-7 (March 2025)
- **Multimodal support** (vision + language)
- Improved numerical stability
- Better scaling to 14B+ parameters
## Numerical Stability
### Issue: Exponential Overflow
```python
# Problem: exp(u + k) and exp(k) overflow for large keys
# (k > ~88 in FP32, k > ~11 in FP16), and the accumulators aa / ab grow with them
cur = exp(u + k)
wkv = (aa + cur * v) / (ab + cur)  # Can overflow: inf / inf -> NaN
```
### Solution: Log-space Computation
```python
# Stable implementation: carry a running max exponent `pp` so every exp() argument is <= 0.
# aa / bb hold the numerator / denominator scaled by exp(-pp).
# Initial state: aa = bb = 0, pp = very negative (e.g. -1e30).
# w = -exp(time_decay) < 0 is the per-step log-decay; u is the current-token bonus.
ww = u + k
p = max(pp, ww)
e1, e2 = exp(pp - p), exp(ww - p)
wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)  # Numerically stable
# Advance the state, re-centring it on the new running max
ww = pp + w
p = max(ww, k)
e1, e2 = exp(ww - p), exp(k - p)
aa = e1 * aa + e2 * v
bb = e1 * bb + e2
pp = p
```
### Gradient Clipping
```python
# Recommended for training stability
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```
## State Management
### State Shape
```python
# For batch inference
state = torch.zeros(
batch_size,
n_layers,
4, # (att_aa, att_ab, att_x_prev, ffn_x_prev)
d_model
)
```
### State Initialization
```python
# Zero initialization (standard)
state = None # Model creates zero state
# Warm state (from previous conversation)
_, state = model.forward(previous_context, None)
# Use `state` for next turn
```
### State Serialization
```python
# Save conversation state
torch.save(state, 'conversation_state.pt')
# Resume conversation
state = torch.load('conversation_state.pt')
out, state = model.forward(new_tokens, state)
```
## Resources
- Paper (RWKV): https://arxiv.org/abs/2305.13048 (May 2023)
- Paper (RWKV-7): https://arxiv.org/abs/2503.14456 (March 2025)
- GitHub: https://github.com/BlinkDL/RWKV-LM
- Math derivation: https://wiki.rwkv.com/
- CUDA kernels: https://github.com/BlinkDL/RWKV-CUDA
@@ -0,0 +1,386 @@
# RWKV-7: Latest Improvements (March 2025)
## Overview
RWKV-7 is the latest version released in March 2025, introducing multimodal capabilities and improved scaling to 14B+ parameters.
**Paper**: https://arxiv.org/abs/2503.14456 (March 2025)
## Key Improvements Over RWKV-6
### 1. Enhanced Numerical Stability
**Problem in RWKV-6**:
```python
# Exponential operations could overflow for large models
att_aa = exp(w) * att_aa + k * v # Overflow risk!
```
**RWKV-7 Solution**:
```python
# Log-space computation with safe exponentiation
log_att_aa = log_softmax([log(k * v), log_w + log(att_aa)])
att_aa = exp(log_att_aa)
```
**Result**: Stable training up to 14B parameters (RWKV-6 struggled beyond 7B)
### 2. Improved Time-Decay Initialization
**RWKV-6**:
```python
# Simple logarithmic spacing
time_decay[i] = -5.0 + 8.0 * (i / n_layers)
```
**RWKV-7**:
```python
# Adaptive per-head decay with better range
for layer in range(n_layers):
for head in range(n_heads):
# Different heads specialize in different timescales
alpha = (layer / n_layers) ** 0.7 # Non-linear progression
beta = (head / n_heads) * 0.5
time_decay[layer, head] = -6.0 + 9.0 * alpha + beta
# Result: Better long/short-term memory balance
```
**Impact**: 15-20% perplexity improvement on long-context tasks
### 3. Multi-Head Time-Mixing Refinements
**RWKV-6 Multi-Head**:
```python
# Simple concatenation
heads = [head_i(x) for head_i in heads]
output = concat(heads)
```
**RWKV-7 Multi-Head**:
```python
# Attention-style output projection
heads = [head_i(x) for head_i in heads]
concat_heads = concat(heads)
output = output_proj(concat_heads) # Learnable mixing
# Plus: Per-head layer norm
for i, head in enumerate(heads):
heads[i] = head_norm[i](head) # Separate norm per head
```
**Result**: Better head specialization, 8-12% quality improvement
### 4. Rotary Position Encoding (RoPE) Integration
**New in RWKV-7**:
```python
class RWKV7_TimeMix(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.rope = RotaryEmbedding(d_model // n_heads)
def forward(self, x):
k = self.key(x) # (B, T, d_model)
v = self.value(x)
# Apply RoPE to keys
k = self.rope.rotate_queries_or_keys(k)
# WKV with position-aware keys
wkv = self.wkv(k, v)
return wkv
```
**Why useful**: Improves positional awareness without breaking O(n) complexity
### 5. RWKV-7 Block Structure
```python
class RWKV7_Block(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
# Multi-head time-mixing with RoPE
self.att = RWKV7_MultiHeadTimeMix(d_model, n_heads)
# Enhanced channel-mixing
self.ffn = RWKV7_ChannelMix(d_model, hidden_ratio=3.5) # Larger FFN
def forward(self, x, state):
# Pre-norm (like GPT)
att_out, new_state = self.att(self.ln1(x), state)
x = x + att_out
# FFN with gating
ffn_out = self.ffn(self.ln2(x))
x = x + ffn_out
return x, new_state
```
## Multimodal Capabilities
### Vision Encoder Integration
**Architecture**:
```python
class RWKV7_Multimodal(nn.Module):
    """Vision encoder + linear projection feeding an RWKV language model.

    Args:
        d_model: width of the RWKV language model; the vision projection maps
            the encoder's 1024-dim tokens into this width. Defaults to 2560.
    """

    def __init__(self, d_model=2560):
        super().__init__()
        # Vision encoder (CLIP-style)
        self.vision_encoder = VisionTransformer(
            patch_size=14,
            d_model=1024,
            n_layers=24
        )
        # Projection to RWKV space (must match the language model width)
        self.vision_proj = nn.Linear(1024, d_model)
        # RWKV language model
        self.rwkv = RWKV7_LanguageModel(d_model=d_model, n_layers=40)

    def forward(self, image, text, state=None):
        """Encode image, prepend its tokens to text, and run the RWKV model.

        NOTE(review): `text` is concatenated with projected vision tokens along
        dim=1, so it presumably must already be embedded to (B, T, d_model) --
        confirm against the caller.
        """
        # Encode image to patches
        vision_tokens = self.vision_encoder(image)  # (B, 256, 1024)
        vision_tokens = self.vision_proj(vision_tokens)  # (B, 256, d_model)
        # Concatenate vision and text tokens
        combined = torch.cat([vision_tokens, text], dim=1)
        # Process with RWKV
        out, state = self.rwkv(combined, state)
        return out, state
```
### Vision-Language Tasks
**Image Captioning**:
```python
model = RWKV7_Multimodal()
# Encode image
image = load_image('cat.jpg')
vision_tokens = model.vision_encoder(image)
# Generate caption
state = None
_, state = model.rwkv(vision_tokens, state) # Process image
# Autoregressive caption generation
caption = []
for _ in range(max_length):
logits, state = model.rwkv(prev_token, state)
next_token = sample(logits)
caption.append(next_token)
```
**VQA (Visual Question Answering)**:
```python
# Question: "What color is the cat?"
question_tokens = tokenizer.encode("What color is the cat?")
# Process image + question
combined = torch.cat([vision_tokens, question_tokens], dim=1)
answer_logits, state = model.rwkv(combined, state)
# Answer: "orange"
```
### Training Multimodal RWKV-7
```python
# Pretrain vision encoder (CLIP-style)
train_vision_encoder(image_text_pairs)
# Freeze vision encoder
model.vision_encoder.requires_grad_(False)
# Train projection + RWKV
for batch in multimodal_dataloader:
    images, captions = batch
    # Clear gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)
    # Forward
    vision_tokens = model.vision_encoder(images)
    vision_tokens = model.vision_proj(vision_tokens)
    n_vis = vision_tokens.shape[1]
    # NOTE(review): captions are used both as model input and as class targets;
    # the input side presumably needs embedding first -- confirm.
    logits, _ = model.rwkv(
        torch.cat([vision_tokens, captions[:, :-1]], dim=1),
        state=None
    )
    # Loss (next token prediction). The output at position p predicts input p+1,
    # so caption token j is predicted at position n_vis - 1 + j: slice from
    # n_vis - 1 to get exactly one logit row per caption token.
    loss = F.cross_entropy(
        logits[:, n_vis - 1:].reshape(-1, vocab_size),
        captions.reshape(-1)
    )
    loss.backward()
    optimizer.step()
```
## Scaling to 14B Parameters
### Model Configuration
| Model | Layers | d_model | n_heads | Params | Context | VRAM (FP16) |
|-------|--------|---------|---------|--------|---------|-------------|
| RWKV-7-1.5B | 24 | 2048 | 16 | 1.5B | Infinite | 3 GB |
| RWKV-7-3B | 32 | 2560 | 20 | 3B | Infinite | 6 GB |
| RWKV-7-7B | 32 | 4096 | 32 | 7B | Infinite | 14 GB |
| RWKV-7-14B | 40 | 5120 | 40 | 14B | Infinite | 28 GB |
### Training Efficiency Improvements
**RWKV-6 Training (7B)**:
- Speed: 45K tokens/sec (8× A100)
- Memory: 38 GB per GPU (4K sequence)
- Stability: Occasional loss spikes
**RWKV-7 Training (14B)**:
- Speed: 52K tokens/sec (8× A100) - **15% faster**
- Memory: 42 GB per GPU (4K sequence) - **Better utilization**
- Stability: No loss spikes - **Improved stability**
**Key optimization**: Fused CUDA kernels for multi-head WKV
### RWKV-7 vs GPT-3 (14B)
| Metric | RWKV-7-14B | GPT-3-13B | Advantage |
|--------|------------|-----------|-----------|
| Training Speed | 52K tok/s | 28K tok/s | 1.9× |
| Inference (2K ctx) | 6,100 tok/s | 1,800 tok/s | 3.4× |
| Inference (8K ctx) | 5,800 tok/s | 450 tok/s | **12.9×** |
| Memory (inference) | 28 GB | 52 GB | 1.9× |
| Perplexity (Pile) | 6.8 | 7.2 | +6% |
## Production Use Cases
### Microsoft Integration
**Windows Copilot** (Limited Release):
- Uses RWKV-7-3B for on-device inference
- 5-8× faster than GPT-2 with better quality
- Constant memory for infinite context
**Office 365** (Experimental):
- Document summarization with RWKV-7-7B
- Handles 100K+ token documents efficiently
- No KV cache storage needed
### NVIDIA NeMo Support
**NeMo Guardrails with RWKV-7**:
```python
from nemoguardrails import RailsConfig
from nemoguardrails.llm.providers import register_llm_provider
# Register RWKV-7 as LLM backend
register_llm_provider("rwkv7", RWKV7Provider)
config = RailsConfig.from_path("config/")
rails = LLMRails(config, llm_provider="rwkv7")
# Use for content moderation
response = rails.generate(user_input="...")
```
## Benchmarks (RWKV-7 vs RWKV-6)
### Language Modeling
| Dataset | RWKV-6-7B | RWKV-7-7B | Improvement |
|---------|-----------|-----------|-------------|
| Pile (val) | 7.8 | 7.1 | +9% |
| C4 | 9.3 | 8.6 | +8% |
| WikiText-103 | 8.4 | 7.7 | +8% |
| Lambada | 11.2 | 9.8 | +13% |
### Long-Context Tasks (32K context)
| Task | RWKV-6-7B | RWKV-7-7B | Improvement |
|------|-----------|-----------|-------------|
| QuALITY | 52.3 | 61.8 | +18% |
| Qasper | 38.1 | 46.7 | +23% |
| NarrativeQA | 41.2 | 49.5 | +20% |
**RWKV-7's improved time-decay** significantly helps long-context understanding
### Multimodal Benchmarks
| Task | RWKV-7-7B | LLaVA-7B | BLIP-2-7B |
|------|-----------|----------|-----------|
| VQAv2 | 74.2 | 78.5 | 82.1 |
| GQA | 58.3 | 62.1 | 65.4 |
| TextVQA | 51.2 | 58.2 | 60.8 |
| COCO Caption | 118.3 | 125.7 | 132.4 |
**Note**: RWKV-7 competitive but not SOTA on vision (vision-focused models still better)
## Migration from RWKV-6 to RWKV-7
### Model Conversion
```python
# Load RWKV-6 checkpoint
rwkv6_state = torch.load('rwkv6-7b.pth')
# Initialize RWKV-7 model
rwkv7_model = RWKV7_Model(d_model=4096, n_layers=32, n_heads=32)
# Convert weights (mostly compatible)
for key in rwkv6_state:
if 'time_mixing' in key:
# RWKV-7 uses multi-head, need to split
rwkv7_key = convert_key_to_multihead(key)
rwkv7_model.state_dict()[rwkv7_key].copy_(rwkv6_state[key])
else:
# Direct copy
rwkv7_model.state_dict()[key].copy_(rwkv6_state[key])
# Fine-tune on small dataset to adapt
finetune(rwkv7_model, small_dataset, epochs=1)
```
### State Compatibility
**RWKV-6 State**:
```python
state_v6 = (att_aa, att_ab, att_x_prev, ffn_x_prev) # 4 components
```
**RWKV-7 State** (Multi-head):
```python
state_v7 = (
att_aa_heads, # (n_heads, d_model//n_heads)
att_ab_heads, # (n_heads, d_model//n_heads)
att_x_prev,
ffn_x_prev
) # 4 components, but att_* are multi-head
```
**Conversion**:
```python
# Split RWKV-6 state into RWKV-7 multi-head state
def convert_state_v6_to_v7(state_v6, n_heads):
att_aa, att_ab, att_x_prev, ffn_x_prev = state_v6
d_head = att_aa.shape[-1] // n_heads
att_aa_heads = att_aa.view(-1, n_heads, d_head).transpose(0, 1)
att_ab_heads = att_ab.view(-1, n_heads, d_head).transpose(0, 1)
return (att_aa_heads, att_ab_heads, att_x_prev, ffn_x_prev)
```
## Resources
- **Paper**: https://arxiv.org/abs/2503.14456 (RWKV-7, March 2025)
- **GitHub**: https://github.com/BlinkDL/RWKV-LM (v7 branch)
- **Models**: https://huggingface.co/BlinkDL/rwkv-7-world
- **Multimodal Demo**: https://huggingface.co/spaces/BlinkDL/RWKV-7-Multimodal
- **Discord**: https://discord.gg/bDSBUMeFpc
- **Wiki**: https://wiki.rwkv.com/rwkv7
@@ -0,0 +1,369 @@
# RWKV State Management
## Understanding RWKV State
Unlike Transformers with KV cache, RWKV maintains a **fixed-size recurrent state** that summarizes all previous context.
### State Components
```python
state = {
'att_aa': torch.zeros(n_layers, d_model), # Attention numerator accumulator
'att_ab': torch.zeros(n_layers, d_model), # Attention denominator accumulator
'att_x_prev': torch.zeros(n_layers, d_model), # Previous x for time-mixing
'ffn_x_prev': torch.zeros(n_layers, d_model) # Previous x for channel-mixing
}
```
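**Total state size**: `4 × n_layers × d_model` values (activations, not parameters); multiply by the bytes per element for memory (2 for FP16)
| Model | Layers | d_model | State Size |
|-------|--------|---------|------------|
| RWKV-169M | 12 | 768 | ~37K values (72 KiB FP16) |
| RWKV-430M | 24 | 1024 | ~98K values (192 KiB FP16) |
| RWKV-1.5B | 24 | 2048 | ~197K values (384 KiB FP16) |
| RWKV-3B | 32 | 2560 | ~328K values (640 KiB FP16) |
| RWKV-7B | 32 | 4096 | ~524K values (1 MiB FP16) |
| RWKV-14B | 40 | 5120 | ~819K values (1.6 MiB FP16) |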
**Constant memory** regardless of context length!
## State Initialization
### Zero State (Default)
```python
from rwkv.model import RWKV
model = RWKV(model='/path/to/RWKV-4-Pile-1B5', strategy='cuda fp16')
# Start with zero state (no context)
state = None
out, state = model.forward(tokens, state)
```
### Warm State (Preloaded Context)
```python
# Load context once
context = "The capital of France is Paris. The capital of Germany is Berlin."
context_tokens = tokenizer.encode(context)
# Process context to build state
state = None
for token in context_tokens:
_, state = model.forward([token], state)
# Now use warm state for queries
query = " The capital of Italy is"
query_tokens = tokenizer.encode(query)
out, state = model.forward(query_tokens, state)
# Model "remembers" Paris and Berlin examples!
```
### Shared State (Multi-turn Conversations)
```python
# Conversation with persistent state
state = None
# Turn 1
user1 = "My name is Alice."
tokens1 = tokenizer.encode(user1)
_, state = model.forward(tokens1, state)
# Turn 2
user2 = "What is my name?"
tokens2 = tokenizer.encode(user2)
response, state = model.forward(tokens2, state)
# Response: "Alice" (state remembers!)
```
## State Update Rules
### Time-Mixing State Update
```python
# Before processing token t
att_aa_t = att_aa_{t-1} # Previous numerator
att_ab_t = att_ab_{t-1} # Previous denominator
# Compute WKV
wkv_t = (exp(u) * k_t * v_t + att_aa_t) / (exp(u) * k_t + att_ab_t)
# Update state for token t+1
w = -exp(time_decay) # Decay factor
att_aa_{t+1} = exp(w) * att_aa_t + k_t * v_t
att_ab_{t+1} = exp(w) * att_ab_t + k_t
att_x_prev_{t+1} = x_t
```
**Effect of time_decay**:
- **w = -0.01** (small decay): State decays slowly → long memory
- **w = -5.0** (large decay): State decays quickly → short memory
### Channel-Mixing State Update
```python
# Simply store previous x for next token
ffn_x_prev_{t+1} = x_t
```
## State Serialization
### Save/Load State (PyTorch)
```python
import torch
# Save conversation state
state_dict = {
'att_aa': state[0],
'att_ab': state[1],
'att_x_prev': state[2],
'ffn_x_prev': state[3]
}
torch.save(state_dict, 'conversation_123.pt')
# Load state
loaded = torch.load('conversation_123.pt')
state = (loaded['att_aa'], loaded['att_ab'], loaded['att_x_prev'], loaded['ffn_x_prev'])
# Continue conversation
out, state = model.forward(new_tokens, state)
```
### State Compression (Optional)
```python
# FP16 state (half size)
state_fp16 = tuple(s.half() for s in state)
torch.save(state_fp16, 'state_compressed.pt')
# Restore
state = tuple(s.float() for s in torch.load('state_compressed.pt'))
```
## Multi-Session State Management
### Session State Store
```python
class StateManager:
    """In-memory registry that maps a session id to its latest state."""

    def __init__(self):
        # session_id -> state
        self.sessions = dict()

    def get_state(self, session_id):
        """Return the state stored for ``session_id``, or None if there is none."""
        if session_id not in self.sessions:
            return None
        return self.sessions[session_id]

    def save_state(self, session_id, state):
        """Store ``state`` under ``session_id``, replacing any earlier entry."""
        self.sessions.update({session_id: state})

    def clear_session(self, session_id):
        """Forget ``session_id``; ids that were never stored are ignored."""
        self.sessions.pop(session_id, None)
# Usage
manager = StateManager()
# User 1 conversation
state1 = manager.get_state('user_1')
out1, state1 = model.forward(tokens1, state1)
manager.save_state('user_1', state1)
# User 2 conversation (independent state)
state2 = manager.get_state('user_2')
out2, state2 = model.forward(tokens2, state2)
manager.save_state('user_2', state2)
```
### State Expiration
```python
import time


class StateManagerWithExpiry:
    """Session state store whose entries expire after a fixed time-to-live.

    Ages are measured with ``time.monotonic()`` so that system clock
    adjustments (NTP, DST, manual changes) can neither expire a live session
    early nor keep a stale one alive.

    Args:
        expiry_seconds: Lifetime of a saved state in seconds. A state is
            returned only while its age is strictly below this value.
    """

    def __init__(self, expiry_seconds=3600):
        self.sessions = {}  # session_id -> (state, monotonic timestamp)
        self.expiry = expiry_seconds

    def get_state(self, session_id):
        """Return the live state for ``session_id``.

        Returns:
            The stored state, or None when the session is unknown or expired.
            An expired entry is removed as a side effect.
        """
        entry = self.sessions.get(session_id)
        if entry is None:
            return None
        state, timestamp = entry
        if time.monotonic() - timestamp < self.expiry:
            return state
        del self.sessions[session_id]  # Expired
        return None

    def save_state(self, session_id, state):
        """Store ``state`` for ``session_id`` and restart its expiry timer."""
        self.sessions[session_id] = (state, time.monotonic())

    def purge_expired(self):
        """Remove every expired session and return how many were removed.

        ``get_state`` only evicts the entry it is asked about, so sessions
        that are never read again would otherwise stay in memory forever.
        Call this periodically to bound the size of the store.
        """
        now = time.monotonic()
        # Collect ids first: the dict must not change size while iterating.
        stale = [
            session_id
            for session_id, (_, timestamp) in self.sessions.items()
            if now - timestamp >= self.expiry
        ]
        for session_id in stale:
            del self.sessions[session_id]
        return len(stale)
```
## State Interpolation
### Blending States
```python
# Average two states (e.g., merging conversations)
def blend_states(state1, state2, alpha=0.5):
    """Linearly interpolate two states, component by component.

    Each output component is ``alpha * a + (1 - alpha) * b`` for the paired
    components ``a`` (from state1) and ``b`` (from state2). Pairing stops at
    the shorter of the two states.
    """
    mixed = []
    for first, second in zip(state1, state2):
        mixed.append(alpha * first + (1 - alpha) * second)
    return tuple(mixed)
# Example: Blend Alice and Bob conversation contexts
state_blended = blend_states(state_alice, state_bob, alpha=0.7)
# 70% Alice context, 30% Bob context
```
### State Editing
```python
# Manually edit state (advanced)
# Example: Reduce long-term memory influence
def decay_state(state, decay_factor=0.5):
    """Scale down the accumulated attention terms of a 4-component state.

    The first two components (att_aa, att_ab) are multiplied by
    ``decay_factor``; the last two (att_x_prev, ffn_x_prev) are passed
    through untouched.
    """
    numerator, denominator, prev_att_x, prev_ffn_x = state
    scaled_numerator = numerator * decay_factor
    scaled_denominator = denominator * decay_factor
    # The previous-token inputs are returned as-is (keep recent x).
    return (scaled_numerator, scaled_denominator, prev_att_x, prev_ffn_x)
# Usage
state = decay_state(state, decay_factor=0.3) # Forget 70% of history
```
## Batch Inference with States
### Independent Batch States
```python
# Each sequence in batch has separate state
batch_size = 4
states = [None] * batch_size
for i, tokens in enumerate(batch_sequences):
out, states[i] = model.forward(tokens, states[i])
```
### Shared Prefix Optimization
```python
import copy

# All sequences share common prefix (e.g., system prompt)
prefix = "You are a helpful assistant."
prefix_tokens = tokenizer.encode(prefix)
# Compute prefix state once
_, prefix_state = model.forward(prefix_tokens, None)
# Give every sequence its own deep copy of the prefix state.
# `[prefix_state] * batch_size` would only repeat one reference: if
# model.forward updates the state object in place, each query would then
# continue from the previous query's state instead of from the clean prefix.
states = [copy.deepcopy(prefix_state) for _ in range(batch_size)]
# Process user queries (independent)
for i, user_query in enumerate(user_queries):
    tokens = tokenizer.encode(user_query)
    out, states[i] = model.forward(tokens, states[i])
```
## State Debugging
### Inspect State Magnitudes
```python
def inspect_state(state):
    """Print the mean and max absolute value of each state component.

    ``state`` must unpack into exactly four tensors, in the order
    att_aa, att_ab, att_x_prev, ffn_x_prev.
    """
    att_aa, att_ab, att_x_prev, ffn_x_prev = state
    labelled = (
        ("att_aa", att_aa),
        ("att_ab", att_ab),
        ("att_x_prev", att_x_prev),
        ("ffn_x_prev", ffn_x_prev),
    )
    print("State magnitudes:")
    for label, tensor in labelled:
        magnitude = tensor.abs()
        print(f" {label}: mean={magnitude.mean():.4f}, max={magnitude.max():.4f}")
# Usage
inspect_state(state)
```
**Healthy ranges**:
- `att_aa`, `att_ab`: 0.1 - 10.0 (if much larger, may overflow)
- `att_x_prev`, `ffn_x_prev`: Similar to input embedding magnitude
### State Divergence Check
```python
def state_distance(state1, state2):
    """Sum the per-component L2 distances between two states.

    Components are paired positionally; each pair contributes
    ``torch.dist(a, b)`` converted to a Python float.
    """
    total = 0
    for left, right in zip(state1, state2):
        total += torch.dist(left, right).item()
    return total
# Example: Check if states diverged
distance = state_distance(state_alice, state_bob)
print(f"State distance: {distance:.2f}")
# Large distance → very different contexts
```
## Numerical Stability Considerations
### Overflow Prevention
```python
# Issue: att_aa, att_ab can grow unbounded
# If att_aa > 1e10, numerical precision issues
# Solution 1: Periodic normalization
if att_aa.abs().max() > 1e6:
scale = att_aa.abs().max()
att_aa = att_aa / scale
att_ab = att_ab / scale
```
### Underflow Prevention
```python
# Issue: With large negative time_decay, state can underflow to 0
# Solution: Clip time_decay
time_decay = torch.clamp(time_decay, min=-8.0, max=-0.1)
# Ensures state doesn't decay too fast
```
## State vs KV Cache Comparison
### Memory Usage (8K context)
| Model Type | Model Size | KV Cache Size | RWKV State Size |
|------------|------------|---------------|-----------------|
| Transformer | 1.3B | 4.1 GB | - |
| **RWKV** | **1.5B** | **-** | **196 KB** |
| Transformer | 7B | 21.3 GB | - |
| **RWKV** | **7B** | **-** | **524 KB** |
**RWKV advantage**: Over 10,000× smaller than the KV cache (roughly 20,000–40,000× for the 8K-context figures above)!
### Information Retention
**KV Cache (Transformer)**:
- Perfect: Stores all previous keys and values
- Retrieval: Exact attention to any previous token
- Cost: O(n) memory growth
**RWKV State**:
- Lossy: Compressed representation of history
- Retrieval: Weighted blend of previous tokens (decay-based)
- Cost: O(1) constant memory
**Trade-off**: RWKV sacrifices perfect recall for constant memory
## Resources
- State management examples: https://github.com/BlinkDL/ChatRWKV
- Wiki: https://wiki.rwkv.com/state-management
- Discord: https://discord.gg/bDSBUMeFpc (RWKV community)
@@ -0,0 +1,358 @@
---
name: distributed-llm-pretraining-torchtitan
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
---
# TorchTitan - PyTorch Native Distributed LLM Pretraining
## Quick start
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
**Installation**:
```bash
# From PyPI (stable)
pip install torchtitan
# From source (latest features, requires PyTorch nightly)
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
```
**Download tokenizer**:
```bash
# Get HF token from https://huggingface.co/settings/tokens
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
```
**Start training on 8 GPUs**:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```
## Common workflows
### Workflow 1: Pretrain Llama 3.1 8B on single node
Copy this checklist:
```
Single Node Pretraining:
- [ ] Step 1: Download tokenizer
- [ ] Step 2: Configure training
- [ ] Step 3: Launch training
- [ ] Step 4: Monitor and checkpoint
```
**Step 1: Download tokenizer**
```bash
python scripts/download_hf_assets.py \
--repo_id meta-llama/Llama-3.1-8B \
--assets tokenizer \
--hf_token=YOUR_HF_TOKEN
```
**Step 2: Configure training**
Edit or create a TOML config file:
```toml
# llama3_8b_custom.toml
[job]
dump_folder = "./outputs"
description = "Llama 3.1 8B training"
[model]
name = "llama3"
flavor = "8B"
hf_assets_path = "./assets/hf/Llama-3.1-8B"
[optimizer]
name = "AdamW"
lr = 3e-4
[lr_scheduler]
warmup_steps = 200
[training]
local_batch_size = 2
seq_len = 8192
max_norm = 1.0
steps = 1000
dataset = "c4"
[parallelism]
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
[activation_checkpoint]
mode = "selective"
selective_ac_option = "op"
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```
**Step 3: Launch training**
```bash
# 8 GPUs on single node
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
# Or explicitly with torchrun
torchrun --nproc_per_node=8 \
-m torchtitan.train \
--job.config_file ./llama3_8b_custom.toml
```
**Step 4: Monitor and checkpoint**
TensorBoard logs are saved to `./outputs/tb/`:
```bash
tensorboard --logdir ./outputs/tb
```
### Workflow 2: Multi-node training with SLURM
```
Multi-Node Training:
- [ ] Step 1: Configure parallelism for scale
- [ ] Step 2: Set up SLURM script
- [ ] Step 3: Submit job
- [ ] Step 4: Resume from checkpoint
```
**Step 1: Configure parallelism for scale**
For 70B model on 256 GPUs (32 nodes):
```toml
[parallelism]
data_parallel_shard_degree = 32 # FSDP across 32 ranks
tensor_parallel_degree = 8 # TP within node
pipeline_parallel_degree = 1 # No PP for 70B
context_parallel_degree = 1 # Increase for long sequences
```
**Step 2: Set up SLURM script**
```bash
#!/bin/bash
#SBATCH --job-name=llama70b
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8

# One srun task per node: torchrun is the per-node launcher and spawns the
# 8 per-GPU workers itself (--nproc_per_node=8). Requesting 8 tasks per node
# would start 8 torchrun launchers on every node.

# Rendezvous endpoint: default to the first node of the allocation unless the
# caller already exported MASTER_ADDR / MASTER_PORT.
export MASTER_ADDR=${MASTER_ADDR:-$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)}
export MASTER_PORT=${MASTER_PORT:-29500}

srun torchrun \
  --nnodes=32 \
  --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
  -m torchtitan.train \
  --job.config_file ./llama3_70b.toml
```
**Step 3: Submit job**
```bash
sbatch multinode_trainer.slurm
```
**Step 4: Resume from checkpoint**
Training auto-resumes if checkpoint exists in configured folder.
### Workflow 3: Enable Float8 training for H100s
Float8 provides 30-50% speedup on H100 GPUs.
```
Float8 Training:
- [ ] Step 1: Install torchao
- [ ] Step 2: Configure Float8
- [ ] Step 3: Launch with compile
```
**Step 1: Install torchao**
```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```
**Step 2: Configure Float8**
Add to your TOML config:
```toml
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output"] # Exclude output layer
[compile]
enable = true
components = ["model", "loss"]
```
**Step 3: Launch with compile**
```bash
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.enable_fsdp_float8_all_gather \
--compile.enable
```
### Workflow 4: 4D parallelism for 405B models
```
4D Parallelism (FSDP + TP + PP + CP):
- [ ] Step 1: Create seed checkpoint
- [ ] Step 2: Configure 4D parallelism
- [ ] Step 3: Launch on 512 GPUs
```
**Step 1: Create seed checkpoint**
Required for consistent initialization across PP stages:
```bash
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
--checkpoint.enable \
--checkpoint.create_seed_checkpoint \
--parallelism.data_parallel_shard_degree 1 \
--parallelism.tensor_parallel_degree 1 \
--parallelism.pipeline_parallel_degree 1
```
**Step 2: Configure 4D parallelism**
```toml
[parallelism]
data_parallel_shard_degree = 8 # FSDP
tensor_parallel_degree = 8 # TP within node
pipeline_parallel_degree = 8 # PP across nodes
context_parallel_degree = 1 # CP for long sequences
[training]
local_batch_size = 32
seq_len = 8192
```
**Step 3: Launch on 512 GPUs**
```bash
# 64 nodes x 8 GPUs = 512 GPUs
srun torchrun --nnodes=64 --nproc_per_node=8 \
-m torchtitan.train \
--job.config_file ./llama3_405b.toml
```
## When to use vs alternatives
**Use TorchTitan when:**
- Pretraining LLMs from scratch (8B to 405B+)
- Need PyTorch-native solution without third-party dependencies
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
- Training on H100s with Float8 support
- Want interoperable checkpoints with torchtune/HuggingFace
**Use alternatives instead:**
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
- **Axolotl/TRL**: Fine-tuning rather than pretraining
- **LitGPT**: Educational, smaller-scale training
## Common issues
**Issue: Out of memory on large models**
Enable activation checkpointing and reduce batch size:
```toml
[activation_checkpoint]
mode = "full" # Instead of "selective"
[training]
local_batch_size = 1
```
Or use gradient accumulation:
```toml
[training]
local_batch_size = 1
global_batch_size = 32 # Accumulates gradients
```
**Issue: TP causes high memory with async collectives**
Set environment variable:
```bash
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
```
**Issue: Float8 training not faster**
Float8 only benefits large GEMMs. Filter small layers:
```toml
[quantize.linear.float8]
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
```
**Issue: Checkpoint loading fails after parallelism change**
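DCP checkpoints can be resharded for a different parallelism configuration at load time. If loading still fails, convert the sharded checkpoint to a single file as a fallback: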
```bash
# Convert sharded checkpoint to single file
python -m torch.distributed.checkpoint.format_utils \
dcp_to_torch checkpoint/step-1000 checkpoint.pt
```
**Issue: Pipeline parallelism initialization**
Create seed checkpoint first (see Workflow 4, Step 1).
## Supported models
| Model | Sizes | Status |
|-------|-------|--------|
| Llama 3.1 | 8B, 70B, 405B | Production |
| Llama 4 | Various | Experimental |
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
| GPT-OSS | 20B, 120B (MoE) | Experimental |
| Qwen 3 | Various | Experimental |
| Flux | Diffusion | Experimental |
## Performance benchmarks (H100)
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|-------|------|-------------|---------|------------|
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
## Advanced topics
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
## Resources
- GitHub: https://github.com/pytorch/torchtitan
- Paper: https://arxiv.org/abs/2410.06511
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44
@@ -0,0 +1,181 @@
# Checkpointing in TorchTitan
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
## Basic Configuration
```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```
## Save Model Only (Smaller Checkpoints)
Exclude optimizer state and training metadata:
```toml
[checkpoint]
enable = true
last_save_model_only = true
export_dtype = "bfloat16" # Optional: export in lower precision
```
## Excluding Keys from Loading
Partial checkpoint loading for modified settings:
```toml
[checkpoint]
enable = true
exclude_from_loading = ["data_loader", "lr_scheduler"]
```
CLI equivalent:
```bash
--checkpoint.exclude_from_loading data_loader,lr_scheduler
```
## Creating Seed Checkpoints
Required for Pipeline Parallelism to ensure consistent initialization:
```bash
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
--checkpoint.enable \
--checkpoint.create_seed_checkpoint \
--parallelism.data_parallel_replicate_degree 1 \
--parallelism.data_parallel_shard_degree 1 \
--parallelism.tensor_parallel_degree 1 \
--parallelism.pipeline_parallel_degree 1 \
--parallelism.context_parallel_degree 1 \
--parallelism.expert_parallel_degree 1
```
This initializes on single CPU for reproducible initialization across any GPU count.
## Async Checkpointing
Reduce checkpoint overhead with async writes:
```toml
[checkpoint]
enable = true
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
```
## HuggingFace Conversion
### During Training
Save directly in HuggingFace format:
```toml
[checkpoint]
last_save_in_hf = true
last_save_model_only = true
```
Load from HuggingFace:
```toml
[checkpoint]
initial_load_in_hf = true
[model]
hf_assets_path = "./path/to/hf/checkpoint"
```
### Offline Conversion
Convert without running training:
```bash
# HuggingFace -> TorchTitan
python ./scripts/checkpoint_conversion/convert_from_hf.py \
  <input_dir> <output_dir> \
  --model_name llama3 \
  --model_flavor 8B
# TorchTitan -> HuggingFace
# --hf_assets_path is the folder the tokenizer/config assets were downloaded
# to (./assets/hf/<repo name>, i.e. Llama-3.1-8B for meta-llama/Llama-3.1-8B).
python ./scripts/checkpoint_conversion/convert_to_hf.py \
  <input_dir> <output_dir> \
  --hf_assets_path ./assets/hf/Llama-3.1-8B \
  --model_name llama3 \
  --model_flavor 8B
```
### Example
```bash
# Convert a locally cached HF snapshot into a TorchTitan checkpoint folder.
python ./scripts/checkpoint_conversion/convert_from_hf.py \
  ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
  ./initial_load_path/ \
  --model_name llama3 \
  --model_flavor 8B
```
## Converting to Single .pt File
Convert DCP sharded checkpoint to single PyTorch file:
```bash
python -m torch.distributed.checkpoint.format_utils \
dcp_to_torch \
torchtitan/outputs/checkpoint/step-1000 \
checkpoint.pt
```
## Checkpoint Structure
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
```
checkpoint/
├── step-500/
│ ├── .metadata
│ ├── __0_0.distcp
│ ├── __0_1.distcp
│ └── ...
└── step-1000/
└── ...
```
## Resume Training
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
```toml
[checkpoint]
load_step = 500 # Resume from step 500
```
## Interoperability with TorchTune
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
## Full Configuration Example
```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
load_step = -1 # -1 = latest, or specify step number
last_save_model_only = true
export_dtype = "bfloat16"
async_mode = "async"
exclude_from_loading = []
last_save_in_hf = false
initial_load_in_hf = false
create_seed_checkpoint = false
```
## Best Practices
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
3. **Pipeline parallelism**: Always create seed checkpoint first
4. **Debugging**: Save frequent checkpoints during development, reduce for production
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows
@@ -0,0 +1,258 @@
# Adding Custom Models to TorchTitan
This guide explains how to add a new model to TorchTitan following the established patterns.
## Directory Structure
```
torchtitan/models/your_model/
├── model/
│ ├── __init__.py
│ ├── args.py # Model arguments
│ ├── model.py # Model definition
│ └── state_dict_adapter.py # HF conversion (optional)
├── infra/
│ ├── __init__.py
│ ├── parallelize.py # TP, FSDP, compile application
│ └── pipeline.py # PP application (optional)
├── train_configs/
│ ├── debug_model.toml
│ └── your_model_XB.toml
├── __init__.py # TrainSpec registration
└── README.md
```
## Step 1: Define Model Arguments
Inherit from `BaseModelArgs`:
```python
# model/args.py
from torchtitan.protocols.model import BaseModelArgs
from dataclasses import dataclass
# Template dataclass holding the model hyperparameters; the defaults below are
# overridden per flavor when instances are created.
@dataclass
class YourModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
vocab_size: int = 128256
# Returns the (nparams, flops) pair computed below; `seq_len` is accepted but
# not used by this template body.
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
"""Return (num_params, flops_per_token) for throughput calculation."""
# NOTE(review): the trailing `...` is a placeholder for the remaining
# parameter terms. As written, `int + ...` raises TypeError, so it must be
# replaced with the real parameter count before this method is called.
nparams = self.vocab_size * self.dim + ... # Calculate params
flops = 6 * nparams # Approximate: 6 * params for forward+backward
return nparams, flops
# Hook for applying job_config overrides; this template changes nothing and
# returns the same instance.
def update_from_config(self, job_config) -> "YourModelArgs":
"""Update args from training config."""
# Override specific args from job_config if needed
return self
```
## Step 2: Define Model
Inherit from `ModelProtocol`:
```python
# model/model.py
import torch
import torch.nn as nn
from torchtitan.protocols.model import ModelProtocol
from .args import YourModelArgs

# NOTE(review): TransformerBlock and RMSNorm are not defined or imported in
# this snippet; presumably they come from your own model code. Confirm.


class YourModel(nn.Module, ModelProtocol):
    """Decoder-style template model: embedding -> blocks -> norm -> output.

    The class must derive from ``nn.Module``: ``super().__init__()``,
    submodule registration (``nn.Embedding``, ``nn.ModuleDict``, ...) and
    ``self.modules()`` in ``init_weights`` all rely on it.

    Args:
        args: Hyperparameters; ``vocab_size``, ``dim`` and ``n_layers`` are
            read here.
    """

    def __init__(self, args: YourModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        # ModuleDict keyed by the layer index as a string.
        self.layers = nn.ModuleDict({
            str(i): TransformerBlock(args) for i in range(args.n_layers)
        })
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens``, run every layer in order, normalize and project."""
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        h = self.norm(h)
        return self.output(h)

    def init_weights(self):
        """Initialize weights recursively."""
        for module in self.modules():
            # self.modules() also yields this model; skip it so the call does
            # not recurse into itself.
            if hasattr(module, 'init_weights') and module is not self:
                module.init_weights()
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
```
**Important guidelines**:
- Write single-device model code (parallelism applied externally)
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
- Make input/output layers optional for PP compatibility
- Define `init_weights()` recursively
## Step 3: Parallelize Function
```python
# infra/parallelize.py
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import parallelize_module

# NOTE(review): YourModel, DeviceMesh, ParallelDims, JobConfig and the
# apply_tp / apply_ac / apply_fsdp helpers are not imported or defined in this
# snippet; import or implement them in your module before use.


def parallelize_your_model(
    model: YourModel,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply the configured parallelism and memory optimizations to ``model``.

    Steps run in a fixed order: TP -> AC -> compile -> FSDP.

    Returns:
        The resulting model (the ``torch.compile`` wrapper when compile is
        enabled, otherwise the input module).
    """
    # Apply in this order: TP -> AC -> compile -> FSDP
    # 1. Tensor Parallelism
    if parallel_dims.tp_enabled:
        apply_tp(model, world_mesh["tp"], job_config)
    # 2. Activation Checkpointing
    # Run for every enabled mode ("selective" as well as "full"); checking
    # only for "full" would silently ignore a selective-AC configuration.
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config)
    # 3. torch.compile
    if job_config.compile.enable:
        model = torch.compile(model)
    # 4. FSDP
    if parallel_dims.dp_enabled:
        apply_fsdp(model, world_mesh["dp"], job_config)
    return model
```
## Step 4: Create TrainSpec
```python
# __init__.py
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
from .model.model import YourModel
from .model.args import YourModelArgs
from .infra.parallelize import parallelize_your_model
MODEL_CONFIGS = {
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
}
# Build the TrainSpec for one model flavor.
# `flavor` indexes MODEL_CONFIGS directly, so an unknown flavor raises KeyError.
# NOTE(review): the build_* callables passed below are not imported in this
# snippet; presumably they are torchtitan's shared builders. Confirm the import
# paths for your torchtitan version.
def get_train_spec(flavor: str) -> TrainSpec:
return TrainSpec(
model_cls=YourModel,
model_args=MODEL_CONFIGS[flavor],
parallelize_fn=parallelize_your_model,
pipeline_fn=None, # Or your_pipeline_fn for PP
build_optimizer_fn=build_optimizer, # Reuse existing
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
build_dataloader_fn=build_dataloader, # Reuse existing
build_tokenizer_fn=build_tokenizer, # Reuse existing
build_loss_fn=build_loss, # Reuse existing
state_dict_adapter=None, # Or YourStateDictAdapter
)
# Register so train.py can find it
register_train_spec("your_model", get_train_spec)
```
## Step 5: State Dict Adapter (Optional)
For HuggingFace checkpoint conversion:
```python
# model/state_dict_adapter.py
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
class YourStateDictAdapter(BaseStateDictAdapter):
    """Renames state-dict keys between torchtitan and HuggingFace layouts.

    Values are passed through unchanged; only the keys are translated, using
    the ``_convert_key_to_hf`` / ``_convert_key_from_hf`` helpers.
    """

    def to_hf(self, state_dict: dict) -> dict:
        """Convert torchtitan state dict to HF format."""
        return {
            self._convert_key_to_hf(name): tensor
            for name, tensor in state_dict.items()
        }

    def from_hf(self, state_dict: dict) -> dict:
        """Convert HF state dict to torchtitan format."""
        return {
            self._convert_key_from_hf(name): tensor
            for name, tensor in state_dict.items()
        }
```
## Step 6: Training Config
```toml
# train_configs/your_model_8b.toml
[job]
dump_folder = "./outputs"
description = "Your Model 8B training"
[model]
name = "your_model"
flavor = "8B"
[optimizer]
name = "AdamW"
lr = 3e-4
[training]
local_batch_size = 2
seq_len = 8192
steps = 1000
dataset = "c4"
[parallelism]
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
```
## Step 7: Register Model
Add to `torchtitan/models/__init__.py`:
```python
from .your_model import get_train_spec as get_your_model_train_spec
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
```
## Testing
### Numerics Test
Compare output with HuggingFace implementation:
```python
# Sketch of a numerics parity test: run the same random token batch through the
# torchtitan model and the HuggingFace model and require matching logits.
# NOTE(review): `args`, `vocab_size` and the `...` arguments are placeholders
# that this snippet does not define; fill them in for your model.
def test_numerics():
# Load same checkpoint into both implementations
tt_model = YourModel(args).load_checkpoint(...)
hf_model = HFYourModel.from_pretrained(...)
# Compare outputs
# One sequence of 128 random token ids drawn from [0, vocab_size)
input_ids = torch.randint(0, vocab_size, (1, 128))
tt_output = tt_model(input_ids)
# The HF output is read through its `.logits` attribute
hf_output = hf_model(input_ids).logits
# Both absolute and relative tolerance are 1e-4
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
```
### Loss Convergence
Compare loss curves with verified baseline (see `docs/converging.md`).
### Performance Benchmark
Add benchmark config to `benchmarks/` folder.
## Guiding Principles
1. **Readability over flexibility**: Don't over-abstract
2. **Minimal model changes**: Parallelism applied externally
3. **Clean, minimal codebase**: Reuse existing components where possible
4. **Single-device semantics**: Model code should work on single GPU
@@ -0,0 +1,133 @@
# Float8 Training in TorchTitan
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
## Hardware Requirements
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
- Blackwell GPUs for MXFP8 training
## Installation
```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```
## Usage: Tensorwise Scaling
Standard Float8 with tensorwise dynamic scaling:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.enable_fsdp_float8_all_gather \
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
--compile.enable
```
### Key Arguments
| Argument | Description |
|----------|-------------|
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
## Usage: Rowwise Scaling
Higher accuracy than tensorwise scaling:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.recipe_name rowwise \
--compile.enable
```
## Filtering Layers
Not all layers benefit from Float8. Filter small layers:
```bash
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
```
### Auto-filtering
Automatically skip layers too small to benefit:
```bash
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
```
Thresholds based on H100 microbenchmarks where speedup > overhead.
## TOML Configuration
```toml
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]
[compile]
enable = true
components = ["model", "loss"]
```
## How Float8 Works with Distributed Training
### Single Device
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
```python
# Float8 matmul requires scales
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
```
### FSDP + Float8
1. Cast sharded high-precision weights (1/N per rank) to float8
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
3. Communicate `max(abs)` across ranks for scale computation
4. At forward start, have unsharded float8 weights ready
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
### TP + Float8
- **Input**: Cast sharded input to float8, all-gather in float8
- **Weights**: Communicate `max(abs)` for sharded weights
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
## Scaling Strategies
| Strategy | Status | Description |
|----------|--------|-------------|
| Tensorwise dynamic | Stable | Single scale per tensor |
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
## Performance Gains
From benchmarks on H100:
| Configuration | TPS/GPU | vs Baseline |
|---------------|---------|-------------|
| FSDP only | 5,762 | - |
| FSDP + compile | 6,667 | +16% |
| FSDP + compile + Float8 | 8,532 | +48% |
## Determining Float8 Benefit
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
## MXFP8 Training (Blackwell)
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
@@ -0,0 +1,126 @@
# FSDP2 in TorchTitan
## Why FSDP2?
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
### Key improvements over FSDP1
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
- **Simplified API**: Fewer arguments, no wrapper class
### Performance
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
## API Reference
```python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
@contract(state_cls=FSDPState)
def fully_shard(
module: nn.Module,
*,
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:
```
## Sharding Strategies (ZeRO Equivalents)
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|---------------------|------------------|-----------|
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
## Meta-Device Initialization
FSDP2 supports materializing tensors onto GPU _after_ sharding:
```python
# Initialize on meta device (no memory)
with torch.device("meta"):
model = Transformer()
# Apply FSDP2 sharding
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
# Parameters still on meta device
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# Allocate sharded parameters on GPU
model.to_empty(device="cuda")
# Initialize weights
model.init_weights()
```
## State Dict Differences
| Operation | FSDP1 | FSDP2 |
|-----------|-------|-------|
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
## Mixed Precision
```python
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
output_dtype=torch.bfloat16,
cast_forward_inputs=True,
)
fully_shard(model, mp_policy=mp_policy)
```
## HSDP (Hybrid Sharded Data Parallel)
For 2D parallelism with replication + sharding:
```python
from torch.distributed.device_mesh import init_device_mesh
# Replicate across 4 groups, shard within 8 GPUs each
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
fully_shard(model, mesh=mesh)
```
## Configuration in TorchTitan
```toml
[parallelism]
# FSDP sharding degree (-1 = auto, use all available GPUs)
data_parallel_shard_degree = -1
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
data_parallel_replicate_degree = 1
```
## Removed Arguments from FSDP1
These FSDP1 arguments are no longer needed:
- `auto_wrap_policy`: Apply `fully_shard` directly to modules
- `backward_prefetch`: Always uses BACKWARD_PRE
- `param_init_fn`: Use meta-device initialization
- `device_id`: Uses mesh's device automatically
- `sync_module_states`: Not needed with DTensor
- `limit_all_gathers`: New memory management doesn't need it
- `use_orig_params`: Always true (no FlatParameter)
@@ -0,0 +1,5 @@
# Skills Coming Soon
This directory will contain high-quality AI research skills for tokenization.
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
@@ -0,0 +1,516 @@
---
name: huggingface-tokenizers
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
dependencies: [tokenizers, transformers, datasets]
---
# HuggingFace Tokenizers - Fast Tokenization for NLP
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
## When to use HuggingFace Tokenizers
**Use HuggingFace Tokenizers when:**
- Need extremely fast tokenization (<20s per GB of text)
- Training custom tokenizers from scratch
- Want alignment tracking (token → original text position)
- Building production NLP pipelines
- Need to tokenize large corpora efficiently
**Performance**:
- **Speed**: <20 seconds to tokenize 1GB on CPU
- **Implementation**: Rust core with Python/Node.js bindings
- **Efficiency**: 10-100× faster than pure Python implementations
**Use alternatives instead**:
- **SentencePiece**: Language-independent, used by T5/ALBERT
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
## Quick start
### Installation
```bash
# Install tokenizers
pip install tokenizers
# With transformers integration
pip install tokenizers transformers
```
### Load pretrained tokenizer
```python
from tokenizers import Tokenizer
# Load from HuggingFace Hub
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
# Encode text (the pretrained BERT post-processor wraps the input in [CLS] ... [SEP])
output = tokenizer.encode("Hello, how are you?")
print(output.tokens) # ['[CLS]', 'hello', ',', 'how', 'are', 'you', '?', '[SEP]']
print(output.ids) # [101, 7592, 1010, 2129, 2024, 2017, 1029, 102]
# Decode back (special tokens are skipped by default)
text = tokenizer.decode(output.ids)
print(text) # "hello, how are you?"
```
### Train custom BPE tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize tokenizer with BPE model
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
min_frequency=2
)
# Train on files
files = ["train.txt", "validation.txt"]
tokenizer.train(files, trainer)
# Save
tokenizer.save("my-tokenizer.json")
```
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
### Batch encoding with padding
```python
# Enable padding; look up the [PAD] id instead of hard-coding it, because the
# id differs between vocabularies (0 for bert-base-uncased, 3 for the custom
# tokenizer trained above).
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
# Encode batch
texts = ["Hello world", "This is a longer sentence"]
encodings = tokenizer.encode_batch(texts)
for encoding in encodings:
    print(encoding.ids)
# Example output with the pretrained bert-base-uncased tokenizer ([PAD] id is 0):
# [101, 7592, 2088, 102, 0, 0, 0]
# [101, 2023, 2003, 1037, 2936, 6251, 102]
```
## Tokenization algorithms
### BPE (Byte-Pair Encoding)
**How it works**:
1. Start with character-level vocabulary
2. Find most frequent character pair
3. Merge into new token, add to vocabulary
4. Repeat until vocabulary size reached
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=50257,
special_tokens=["<|endoftext|>"],
min_frequency=2
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Handles OOV words well (breaks into subwords)
- Flexible vocabulary size
- Good for morphologically rich languages
**Trade-offs**:
- Tokenization depends on merge order
- May split common words unexpectedly
### WordPiece
**How it works**:
1. Start with character vocabulary
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
3. Merge highest scoring pair
4. Repeat until vocabulary size reached
**Used by**: BERT, DistilBERT, MobileBERT
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import BertNormalizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = Whitespace()
trainer = WordPieceTrainer(
vocab_size=30522,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##"
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
```
**Advantages**:
- Prioritizes meaningful merges (high score = semantically related)
- Used successfully in BERT (state-of-the-art results)
**Trade-offs**:
- Unknown words become `[UNK]` if no subword match
- Saves only the final vocabulary, not merge rules (encoding uses greedy longest-match, so segmentations can differ from BPE)
### Unigram
**How it works**:
1. Start with large vocabulary (all substrings)
2. Compute loss for corpus with current vocabulary
3. Remove tokens with minimal impact on loss
4. Repeat until vocabulary size reached
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>"
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Probabilistic (finds most likely tokenization)
- Works well for languages without word boundaries
- Handles diverse linguistic contexts
**Trade-offs**:
- Computationally expensive to train
- More hyperparameters to tune
## Tokenization pipeline
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
### Normalization
Clean and standardize text:
```python
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
tokenizer.normalizer = Sequence([
NFD(), # Unicode normalization (decompose)
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Héllo WORLD"
# After normalization: "hello world"
```
**Common normalizers**:
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
- `Lowercase()` - Convert to lowercase
- `StripAccents()` - Remove accents (é → e)
- `Strip()` - Remove leading/trailing whitespace
- `Replace(pattern, content)` - Regex replacement
### Pre-tokenization
Split text into word-like units:
```python
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
# Split on whitespace and punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation()
])
# Input: "Hello, world!"
# After pre-tokenization: ["Hello", ",", "world", "!"]
```
**Common pre-tokenizers**:
- `Whitespace()` - Split on whitespace and word boundaries (regex `\w+|[^\w\s]+`); use `WhitespaceSplit()` to split on whitespace only
- `ByteLevel()` - GPT-2 style byte-level splitting
- `Punctuation()` - Isolate punctuation
- `Digits(individual_digits=True)` - Split digits individually
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
### Post-processing
Add special tokens for model input:
```python
from tokenizers.processors import TemplateProcessing
# BERT-style: [CLS] sentence [SEP]
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 1),
("[SEP]", 2),
],
)
```
**Common patterns**:
```python
# GPT-2: sentence <|endoftext|>
TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)]
)
# RoBERTa: <s> sentence </s>
TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[("<s>", 0), ("</s>", 2)]
)
```
## Alignment tracking
Track token positions in original text:
```python
# Keep the original string so offsets can be mapped back onto it
text = "Hello, world!"
output = tokenizer.encode(text)
# Get token offsets
for token, offset in zip(output.tokens, output.offsets):
    start, end = offset  # character span [start, end) in `text`
    print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
# Output:
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
```
**Use cases**:
- Named entity recognition (map predictions back to text)
- Question answering (extract answer spans)
- Token classification (align labels to original positions)
## Integration with transformers
### Load with AutoTokenizer
```python
from transformers import AutoTokenizer
# AutoTokenizer automatically uses fast tokenizers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Convert custom tokenizer to transformers
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer (unk_token must match the one passed to the wrapper below)
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# ... train tokenizer ...
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
# Use like any transformers tokenizer
outputs = transformers_tokenizer(
"Hello world",
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
```
## Common patterns
### Train from iterator (large datasets)
```python
from datasets import load_dataset
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
# Create batch iterator
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
# Train tokenizer
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # For progress bar
)
```
**Performance**: Processes 1GB in ~10-20 minutes
### Enable truncation and padding
```python
# Enable truncation
tokenizer.enable_truncation(max_length=512)
# Enable padding
tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"),
pad_token="[PAD]",
length=512 # Fixed length, or None for batch max
)
# Encode with both
output = tokenizer.encode("This is a long sentence that will be truncated...")
print(len(output.ids)) # 512
```
### Multi-processing
```python
from tokenizers import Tokenizer
from multiprocessing import Pool
# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")
def encode_batch(texts):
return tokenizer.encode_batch(texts)
# Process large corpus in parallel
with Pool(8) as pool:
# Split corpus into chunks
chunk_size = 1000
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
# Encode in parallel
results = pool.map(encode_batch, chunks)
```
**Speedup**: 5-8× with 8 cores
## Performance benchmarks
### Training speed
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|-------------|-----------------|-----------------|--------------|
| 10 MB | 15 sec | 18 sec | 25 sec |
| 100 MB | 1.5 min | 2 min | 4 min |
| 1 GB | 15 min | 20 min | 40 min |
**Hardware**: 16-core CPU, tested on English Wikipedia
### Tokenization speed
| Implementation | 1 GB corpus | Throughput |
|----------------|-------------|---------------|
| Pure Python | ~20 minutes | ~50 MB/min |
| HF Tokenizers | ~15 seconds | ~4 GB/min |
| **Speedup** | **80×** | **80×** |
**Test**: English text, average sentence length 20 words
### Memory usage
| Task | Memory |
|-------------------------|---------|
| Load tokenizer | ~10 MB |
| Train BPE (30k vocab) | ~200 MB |
| Encode 1M sentences | ~500 MB |
## Supported models
Pre-trained tokenizers available via `from_pretrained()`:
**BERT family**:
- `bert-base-uncased`, `bert-large-cased`
- `distilbert-base-uncased`
- `roberta-base`, `roberta-large`
**GPT family**:
- `gpt2`, `gpt2-medium`, `gpt2-large`
- `distilgpt2`
**T5 family**:
- `t5-small`, `t5-base`, `t5-large`
- `google/flan-t5-xxl`
**Other**:
- `facebook/bart-base`, `facebook/mbart-large-cc25`
- `albert-base-v2`, `albert-xlarge-v2`
- `xlm-roberta-base`, `xlm-roberta-large`
Browse all: https://huggingface.co/models?library=tokenizers
## References
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
## Resources
- **Docs**: https://huggingface.co/docs/tokenizers
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
- **Version**: 0.20.0+
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012), Unigram (Kudo, 2018)
@@ -0,0 +1,653 @@
# Tokenization Algorithms Deep Dive
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
## Byte-Pair Encoding (BPE)
### Algorithm overview
BPE iteratively merges the most frequent pair of tokens in a corpus.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all adjacent token pairs
3. Merge most frequent pair into new token
4. Add new token to vocabulary
5. Update corpus with new token
6. Repeat until vocabulary size reached
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
'l' + 'o': 7
'o' + 'w': 7
...
Merge: 'e' + 's' → 'es'
Updated corpus:
low: 5
lower: 2
newest: 6 → newes|t: 6
widest: 3 → wides|t: 3
Vocabulary: [a-z] + ['es']
```
**Iteration 2**:
```
Count pairs:
'es' + 't': 9 ← most frequent
'l' + 'o': 7
...
Merge: 'es' + 't' → 'est'
Updated corpus:
low: 5
lower: 2
newest: 6 → new|est: 6
widest: 3 → wid|est: 3
Vocabulary: [a-z] + ['es', 'est']
```
**Continue until desired vocabulary size...**
### Tokenization with trained BPE
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
Tokenize "lowest":
```
Step 1: Split into characters
['l', 'o', 'w', 'e', 's', 't']
Step 2: Apply merges in order learned during training
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
- Merge 'lo' + 'w' → 'low' (if learned)
- Merge 'e' + 's' → 'es' (learned)
- Merge 'es' + 't' → 'est' (learned)
Final: ['low', 'est']
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=1000,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
# Train
corpus = [
"This is a sample corpus for BPE training.",
"BPE learns subword units from the training data.",
# ... more sentences
]
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("This is tokenization")
print(output.tokens) # ['This', 'is', 'token', 'ization']
```
### Byte-level BPE (GPT-2 variant)
**Problem**: Character-level BPE needs every character in its base vocabulary; Unicode has 140,000+ characters, so characters unseen during training become unknown tokens
**Solution**: Operate on byte level (256 bytes)
```python
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
# This handles ALL possible characters, including emojis
text = "Hello 🌍 世界"
tokens = tokenizer.encode(text).tokens
```
**Advantages**:
- Handles any Unicode character (256 byte coverage)
- No unknown tokens (worst case: bytes)
- Used by GPT-2, GPT-3, BART
**Trade-offs**:
- Slightly worse compression (bytes vs characters)
- More tokens for non-ASCII text
### BPE variants
**SentencePiece BPE**:
- Language-independent (no pre-tokenization)
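- Treats input as a raw Unicode character stream (spaces encoded as ▁)
- Used by LLaMA and Mistral (T5, ALBERT, and XLNet use SentencePiece's Unigram model instead)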
**BPE-dropout** (Provilkov et al., 2020):
- Dropout during training (randomly skip merges)
- More robust tokenization at inference
- Reduces overfitting to training data
## WordPiece
### Algorithm overview
WordPiece is similar to BPE but uses a different merge selection criterion.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all token pairs
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
4. Merge pair with highest score
5. Repeat until vocabulary size reached
### Why different scoring?
**BPE**: Merges most frequent pairs
- "aa" appears 100 times → high priority
- Even if 'a' appears 1000 times alone
**WordPiece**: Merges pairs that are semantically related
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
- Prioritizes pairs that appear together more than expected
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
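Count frequencies (simplified: ignoring the ## continuation prefix):
'e': 17 (lower: 2, newest: 2 × 6, widest: 3)
's': 9
't': 9
'l': 7, 'o': 7, 'i': 3, 'd': 3
...
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3)
's' + 't': 9 (newest: 6, widest: 3)
'l' + 'o': 7 (low: 5, lower: 2)
'i' + 'd': 3 (widest: 3)
...
Compute scores:
score('e' + 's') = 9 / (17 × 9) = 0.059
score('s' + 't') = 9 / (9 × 9) = 0.111
score('l' + 'o') = 7 / (7 × 7) = 0.143
score('i' + 'd') = 3 / (3 × 3) = 0.333 ← highest score
Choose: 'i' + 'd' → 'id' (BPE would have merged the more frequent 'e' + 's')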
```
**Key difference**: WordPiece prioritizes pairs whose parts rarely occur apart (strong association), even when the pair itself is less frequent than the pair BPE would pick.
### Tokenization with WordPiece
Given vocabulary: `['##e', '##s', '##t', '##est', 'l', 'o', 'w', 'new', 'low']`
Tokenize "lowest":
```
Step 1: Find longest matching prefix
'lowest' → 'low' (matches)
Step 2: Find longest match for remainder (continuation pieces carry the ## prefix)
'##est' → '##est' (matches)
Final: ['low', '##est']
```
**If no match**:
```
Tokenize "unknownword":
'unknownword' → no match
'unknown' → no match
'unkn' → no match
'un' → no match
'u' → no match
→ [UNK]
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
# Initialize BERT-style tokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization (lowercase, accent stripping)
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization (whitespace + punctuation)
tokenizer.pre_tokenizer = BertPreTokenizer()
# Configure trainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT vocab size
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##" # BERT uses ##
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization works great!")
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
```
### Subword prefix
**BERT uses `##` prefix**:
```
"unbelievable" → ['un', '##believ', '##able']
```
**Why?**
- Indicates token is a continuation
- Allows reconstruction: remove ##, concatenate
- Helps model distinguish word boundaries
### WordPiece advantages
**Semantic merges**:
- Prioritizes meaningful combinations
- "qu" has high score (always together)
- "qx" has low score (rare combination)
**Better for morphology**:
- Captures affixes: un-, -ing, -ed
- Preserves word stems
**Trade-offs**:
- Slower training than BPE
- Stores only the final vocabulary, not merges (the learned merge order cannot be replayed at encoding time)
- Original implementation not open-source (HF reimplementation)
## Unigram
### Algorithm overview
Unigram works backward: start with large vocabulary, remove tokens.
**Training process**:
1. Initialize with large vocabulary (all substrings)
2. Estimate probability of each token (frequency-based)
3. For each token, compute loss increase if removed
4. Remove the tokens with the lowest loss impact (typically 20-25% per round; controlled by `shrinking_factor`)
5. Re-estimate probabilities
6. Repeat until desired vocabulary size
### Probabilistic tokenization
**Unigram assumption**: Each token is independent.
Given vocabulary with probabilities:
```
P('low') = 0.02
P('l') = 0.01
P('o') = 0.015
P('w') = 0.01
P('est') = 0.03
P('e') = 0.02
P('s') = 0.015
P('t') = 0.015
```
Tokenize "lowest":
```
Option 1: ['low', 'est']
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
Option 2: ['l', 'o', 'w', 'est']
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
Option 3: ['low', 'e', 's', 't']
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.00000009
Choose option 1 (highest probability)
```
### Viterbi algorithm
Finding best tokenization is expensive (exponential possibilities).
**Viterbi algorithm** (dynamic programming):
```python
import math


def tokenize_viterbi(word, vocab, probs):
    """Return the most probable segmentation of ``word`` under a unigram model.

    Args:
        word: String to segment.
        vocab: Container of known tokens (a set gives O(1) membership tests).
        probs: Mapping from token to its probability (each must be > 0).

    Returns:
        List of tokens that concatenate to ``word`` with maximal summed
        log-probability; ``[]`` if ``word`` is empty or cannot be segmented
        using ``vocab``.
    """
    n = len(word)
    unreachable = float("-inf")
    # best[i] = (log-probability, tokens) of the best segmentation of word[:i].
    # Prefixes with no valid segmentation keep -inf and are never extended.
    best = [(unreachable, []) for _ in range(n + 1)]
    best[0] = (0.0, [])

    for end in range(1, n + 1):
        # Try every possible last token word[start:end]
        for start in range(end):
            prev_score, prev_tokens = best[start]
            if prev_score == unreachable:
                continue
            token = word[start:end]
            if token in vocab:
                score = prev_score + math.log(probs[token])
                # Strict '>' keeps the first-found segmentation on ties
                if score > best[end][0]:
                    best[end] = (score, prev_tokens + [token])

    return best[n][1]
```
**Time complexity**: O(n²) vocabulary lookups with a hash-based vocab (O(n × max_token_length) if candidate length is capped) vs O(2^n) brute force
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
# Initialize
tokenizer = Tokenizer(Unigram())
# Configure trainer
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Max token length
n_sub_iterations=2, # EM iterations
shrinking_factor=0.75 # Remove 25% each iteration
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization with Unigram")
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
```
### Unigram advantages
**Probabilistic**:
- Multiple valid tokenizations
- Can sample different tokenizations (data augmentation)
**Subword regularization**:
```python
# Sample different tokenizations
# NOTE(review): Tokenizer.encode() appears to be deterministic, so this loop
# would print the same segmentation three times. Sampling is presumably only
# available through the SentencePiece runtime, e.g.
# sp.encode(text, out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
# — confirm against the tokenizers / sentencepiece documentation.
for _ in range(3):
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
print(tokens)
# Illustrative sampled segmentations (what subword regularization produces):
# ['token', 'ization']
# ['tok', 'en', 'ization']
# ['token', 'iz', 'ation']
```
**Language-independent**:
- No word boundaries needed
- Works for CJK languages (Chinese, Japanese, Korean)
- Treats input as character stream
**Trade-offs**:
- Slower training (EM algorithm)
- More hyperparameters
- Larger model (stores probabilities)
## Algorithm comparison
### Training speed
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|------------|--------------|----------------|-------------|
| BPE | 10-15 sec | 1-2 min | 10-20 min |
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
**Tested on**: 16-core CPU, 30k vocab
### Tokenization quality
Tested on English Wikipedia (tokens per word and unknown-token rate):
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|------------|------------|-------------|--------------|
| BPE | 30k | 1.3 | 0.5% |
| WordPiece | 30k | 1.2 | 1.2% |
| Unigram | 8k | 1.5 | 0.3% |
**Key observations**:
- WordPiece: Slightly better compression
- BPE: Lower unknown rate than WordPiece at the same vocab size (Unigram is lowest overall)
- Unigram: Smallest vocab, good coverage
### Compression ratio
Characters per token (higher = better compression):
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|----------|-----------|-----------------|--------------|
| English | 4.2 | 4.5 | 3.8 |
| Chinese | 2.1 | 2.3 | 2.5 |
| Arabic | 3.5 | 3.8 | 3.2 |
**Best for each**:
- English: WordPiece
- Chinese: Unigram (language-independent)
- Arabic: WordPiece
### Use case recommendations
**BPE** - Best for:
- English language models
- Code (handles symbols well)
- Fast training needed
- **Models**: GPT-2, GPT-3, RoBERTa, BART
**WordPiece** - Best for:
- Masked language modeling (BERT-style)
- Morphologically rich languages
- Semantic understanding tasks
- **Models**: BERT, DistilBERT, ELECTRA
**Unigram** - Best for:
- Multilingual models
- Languages without word boundaries (CJK)
- Data augmentation via subword regularization
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
## Advanced topics
### Handling rare words
**BPE approach**:
```
"antidisestablishmentarianism"
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
**WordPiece approach**:
```
"antidisestablishmentarianism"
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
```
**Unigram approach**:
```
"antidisestablishmentarianism"
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
### Handling numbers
**Challenge**: Infinite number combinations
**BPE solution**: Byte-level (handles any digit sequence)
```python
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
# Handles any number
# "123456789" → byte-level tokens
```
**WordPiece solution**: Digit pre-tokenization
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually or as groups
tokenizer.pre_tokenizer = Digits(individual_digits=True)
# "123" → ['1', '2', '3']
```
**Unigram solution**: Learns common number patterns
```python
# Learns patterns during training
# "2023" → ['202', '3'] or ['20', '23']
```
### Handling case sensitivity
**Lowercase (BERT)**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
# "Hello WORLD" → "hello world" → ['hello', 'world']
```
**Preserve case (GPT-2)**:
```python
# No case normalization
tokenizer.normalizer = None
# "Hello WORLD" → ['Hello', 'WORLD']
```
**Cased tokens (RoBERTa)**:
```python
# Learns separate tokens for different cases
# Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
```
### Handling emojis and special characters
**Byte-level (GPT-2)**:
```python
tokenizer.pre_tokenizer = ByteLevel()
# "Hello 🌍 👋" → byte-level representation (always works)
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFKC
tokenizer.normalizer = NFKC()
# "é" (composed) and "é" (decomposed) → normalized to the same composed form
```
## Troubleshooting
### Issue: Poor subword splitting
**Symptom**:
```
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
```
**Solutions**:
1. Increase vocabulary size
2. Train longer (more merge iterations)
3. Lower `min_frequency` threshold
### Issue: Too many unknown tokens
**Symptom**:
```
5% of tokens are [UNK]
```
**Solutions**:
1. Increase vocabulary size
2. Use byte-level BPE (no UNK possible)
3. Verify training corpus is representative
### Issue: Inconsistent tokenization
**Symptom**:
```
"running" → ['run', 'ning']
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
```
**Solutions**:
1. Check normalization consistency
2. Ensure pre-tokenization is deterministic
3. Use Unigram for probabilistic variance
## Best practices
1. **Match algorithm to model architecture**:
- BERT-style → WordPiece
- GPT-style → BPE
- T5-style → Unigram
2. **Use byte-level for multilingual**:
- Handles any Unicode
- No unknown tokens
3. **Test on representative data**:
- Measure compression ratio
- Check unknown token rate
- Inspect sample tokenizations
4. **Version control tokenizers**:
- Save with model
- Document special tokens
- Track vocabulary changes
@@ -0,0 +1,637 @@
# Transformers Integration
Complete guide to using HuggingFace Tokenizers with the Transformers library.
## AutoTokenizer
The easiest way to load tokenizers.
### Loading pretrained tokenizers
```python
from transformers import AutoTokenizer
# Load from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer (Rust-based)
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
if tokenizer.is_fast:
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Fast vs slow tokenizers
| Feature | Fast (Rust) | Slow (Python) |
|--------------------------|----------------|---------------|
| Speed | 5-10× faster | Baseline |
| Alignment tracking | ✅ Full support | ❌ Limited |
| Batch processing | ✅ Optimized | ⚠️ Slower |
| Offset mapping | ✅ Yes | ❌ No |
| Installation | `tokenizers` | Built-in |
**Always use fast tokenizers when available.**
### Check available tokenizers
```python
from transformers import TOKENIZER_MAPPING
# List all fast tokenizers
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
if fast is not None:
print(f"{config_class.__name__}: {fast.__name__}")
```
## PreTrainedTokenizerFast
Wrap custom tokenizers for transformers.
### Convert custom tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer (unk_token must match the one passed to the wrapper below)
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# Split into words first; without a pre-tokenizer merges are learned over whole lines
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
# Save tokenizer
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]"
)
# Save in transformers format
transformers_tokenizer.save_pretrained("my-tokenizer")
```
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
### Use like any transformers tokenizer
```python
# Load
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
# Encode with all transformers features
outputs = tokenizer(
"Hello world",
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
print(outputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
```
## Special tokens
### Default special tokens
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|--------------|---------|---------------|---------|---------|---------|
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
| GPT-2 | <\|endoftext\|> | <\|endoftext\|> | - (unset by default) | <\|endoftext\|> | - |
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
| T5 | - | </s> | <pad> | <unk> | - |
### Adding special tokens
```python
# Add new special tokens
special_tokens_dict = {
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} tokens")
# Resize model embeddings
model.resize_token_embeddings(len(tokenizer))
# Use new tokens
text = "This is an image: <|image|>"
tokens = tokenizer.encode(text)
```
### Adding regular tokens
```python
# Add domain-specific tokens
new_tokens = ["COVID-19", "mRNA", "vaccine"]
num_added = tokenizer.add_tokens(new_tokens)
# These are NOT special tokens (can be split if needed)
tokenizer.add_tokens(new_tokens, special_tokens=False)
# These ARE special tokens (never split)
tokenizer.add_tokens(new_tokens, special_tokens=True)
```
## Encoding and decoding
### Basic encoding
```python
# Single sentence
text = "Hello, how are you?"
encoded = tokenizer(text)
print(encoded)
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```
### Batch encoding
```python
# Multiple sentences
texts = ["Hello world", "How are you?", "I am fine"]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
print(encoded['input_ids'])
# [[101, 7592, 2088, 102, 0, 0],
# [101, 2129, 2024, 2017, 1029, 102],
# [101, 1045, 2572, 2986, 102, 0]]
# Note: padding=True pads to the longest sequence in the batch (6 tokens here), not to max_length
```
### Return tensors
```python
# Return PyTorch tensors
outputs = tokenizer("Hello world", return_tensors="pt")
print(outputs['input_ids'].shape) # torch.Size([1, 4]) -> [CLS] hello world [SEP]
# Return TensorFlow tensors
outputs = tokenizer("Hello world", return_tensors="tf")
# Return NumPy arrays
outputs = tokenizer("Hello world", return_tensors="np")
# Return lists (default)
outputs = tokenizer("Hello world", return_tensors=None)
```
### Decoding
```python
# Decode token IDs
ids = [101, 7592, 2088, 102]
text = tokenizer.decode(ids)
print(text) # "[CLS] hello world [SEP]"
# Skip special tokens
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text) # "hello world"
# Batch decode
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
print(texts) # ["hello", "world"]
```
## Padding and truncation
### Padding strategies
```python
# Pad to max length in batch
tokenizer(texts, padding="longest")
# Pad to model max length
tokenizer(texts, padding="max_length", max_length=128)
# No padding
tokenizer(texts, padding=False)
# Pad to multiple of value (for efficient computation)
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
# Result: length will be 128 (already multiple of 8)
```
### Truncation strategies
```python
# Truncate to max length
tokenizer(text, truncation=True, max_length=10)
# Only truncate first sequence (for pairs)
tokenizer(text1, text2, truncation="only_first", max_length=20)
# Only truncate second sequence
tokenizer(text1, text2, truncation="only_second", max_length=20)
# Truncate longest first (default for pairs)
tokenizer(text1, text2, truncation="longest_first", max_length=20)
# No truncation (error if too long)
tokenizer(text, truncation=False)
```
### Stride for long documents
```python
# For documents longer than max_length
text = "Very long document " * 1000
# Encode with overlap
encodings = tokenizer(
text,
max_length=512,
stride=128, # Overlap between chunks
truncation=True,
return_overflowing_tokens=True,
return_offsets_mapping=True
)
# Get all chunks
num_chunks = len(encodings['input_ids'])
print(f"Split into {num_chunks} chunks")
# Each chunk overlaps by stride tokens
for i, chunk in enumerate(encodings['input_ids']):
print(f"Chunk {i}: {len(chunk)} tokens")
```
**Use case**: Long document QA, sliding window inference
## Alignment and offsets
### Offset mapping
```python
# Get character offsets for each token
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
for token, (start, end) in zip(
encoded.tokens(),
encoded['offset_mapping'] # single (unbatched) input: already a flat list of (start, end) pairs
):
print(f"{token:10s} → [{start:2d}, {end:2d})")
# Output:
# [CLS] → [ 0, 0)
# Hello → [ 0, 5)
# , → [ 5, 6)
# world → [ 7, 12)
# ! → [12, 13)
# [SEP] → [ 0, 0)
```
### Word IDs
```python
# Get word index for each token
encoded = tokenizer("Hello world", return_offsets_mapping=True)
word_ids = encoded.word_ids()
print(word_ids)
# [None, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER, POS tagging)
### Character to token mapping
```python
text = "Machine learning is awesome"
encoded = tokenizer(text, return_offsets_mapping=True)
# Find token for character position
char_pos = 8 # "l" in "learning"
token_idx = encoded.char_to_token(char_pos)
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
# Character 8 is in token 2: learning
```
**Use case**: Question answering (map answer character span to tokens)
### Sequence pairs
```python
# Encode sentence pair
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
# Get sequence IDs (which sequence each token belongs to)
sequence_ids = encoded.sequence_ids()
print(sequence_ids)
# [None, 0, 0, None, 1, 1, None]
# None = special token, 0 = question, 1 = answer
```
## Model integration
### Use with transformers models
```python
from transformers import AutoModel, AutoTokenizer
import torch
# Load model and tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Tokenize
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
# Get embeddings
last_hidden_state = outputs.last_hidden_state
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
```
### Custom model with custom tokenizer
```python
from transformers import BertConfig, BertModel
# Train custom tokenizer
from tokenizers import Tokenizer, models, trainers
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=30000)
tokenizer.train(files=["data.txt"], trainer=trainer)
# Wrap for transformers
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]"
)
# Create model with custom vocab size
config = BertConfig(vocab_size=30000)
model = BertModel(config)
# Use together
inputs = fast_tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```
### Save and load together
```python
# Save both
model.save_pretrained("my-model")
tokenizer.save_pretrained("my-model")
# Directory structure:
# my-model/
# ├── config.json
# ├── pytorch_model.bin
# ├── tokenizer.json
# ├── tokenizer_config.json
# └── special_tokens_map.json
# Load both
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")
```
## Advanced features
### Multimodal tokenization
```python
from transformers import AutoTokenizer
# LLaVA-style (image + text)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Add image placeholder token
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
# Use in prompt
text = "Describe this image: <image>"
inputs = tokenizer(text, return_tensors="pt")
```
### Template formatting
```python
# Chat template
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help?"},
{"role": "user", "content": "What's the weather?"}
]
# Apply chat template (if tokenizer has one)
if hasattr(tokenizer, "apply_chat_template"):
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
```
### Custom template
```python
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# Define chat template
tokenizer.chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
System: {{ message['content'] }}\\n
{%- elif message['role'] == 'user' %}
User: {{ message['content'] }}\\n
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}\\n
{%- endif %}
{%- endfor %}
Assistant:
"""
# Use template
text = tokenizer.apply_chat_template(messages, tokenize=False)
```
## Performance optimization
### Batch processing
```python
# Process large datasets efficiently
from datasets import load_dataset
dataset = load_dataset("imdb", split="train[:1000]")
# Tokenize in batches
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=512
)
# Map over dataset (batched)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
batch_size=1000,
num_proc=4 # Parallel processing
)
```
### Caching
```python
# Enable caching for repeated tokenization
tokenizer = AutoTokenizer.from_pretrained(
"bert-base-uncased",
use_fast=True,
cache_dir="./cache" # Cache tokenizer files
)
# Tokenize with caching
from functools import lru_cache
@lru_cache(maxsize=10000)
def cached_tokenize(text):
return tuple(tokenizer.encode(text))
# Reuses cached results for repeated inputs
```
### Memory efficiency
```python
# For very large datasets, use streaming
from datasets import load_dataset
dataset = load_dataset("pile", split="train", streaming=True)
def process_batch(batch):
# Tokenize
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
# Process tokens...
return tokens
# Process in chunks (memory efficient)
for batch in dataset.batch(batch_size=1000):
processed = process_batch(batch)
```
## Troubleshooting
### Issue: Tokenizer not fast
**Symptom**:
```python
tokenizer.is_fast # False
```
**Solution**: Install tokenizers library
```bash
pip install tokenizers
```
### Issue: Special tokens not working
**Symptom**: Special tokens are split into subwords
**Solution**: Add as special tokens, not regular tokens
```python
# Wrong
tokenizer.add_tokens(["<|image|>"])
# Correct
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
```
### Issue: Offset mapping not available
**Symptom**:
```python
tokenizer("text", return_offsets_mapping=True)
# Error: return_offsets_mapping not supported
```
**Solution**: Use fast tokenizer
```python
from transformers import AutoTokenizer
# Load fast version
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```
### Issue: Padding inconsistent
**Symptom**: Some sequences padded, others not
**Solution**: Specify padding strategy
```python
# Explicit padding
tokenizer(
texts,
padding="max_length", # or "longest"
max_length=128
)
```
## Best practices
1. **Always use fast tokenizers**:
- 5-10× faster
- Full alignment tracking
- Better batch processing
2. **Save tokenizer with model**:
- Ensures reproducibility
- Prevents version mismatches
3. **Use batch processing for datasets**:
- Tokenize with `.map(batched=True)`
- Set `num_proc` for parallelism
4. **Enable caching for repeated inputs**:
- Use `lru_cache` for inference
- Cache tokenizer files with `cache_dir`
5. **Handle special tokens properly**:
- Use `add_special_tokens()` for never-split tokens
- Resize embeddings after adding tokens
6. **Test alignment for downstream tasks**:
- Verify `offset_mapping` is correct
- Test `char_to_token()` on samples
7. **Version control tokenizer config**:
- Save `tokenizer_config.json`
- Document custom templates
- Track vocabulary changes
@@ -0,0 +1,723 @@
# Tokenization Pipeline Components
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
## Pipeline overview
**Full tokenization pipeline**:
```
Raw Text
   ↓
Normalization (cleaning, lowercasing)
   ↓
Pre-tokenization (split into words)
   ↓
Model (apply BPE/WordPiece/Unigram)
   ↓
Post-processing (add special tokens)
   ↓
Token IDs
```
**Decoding reverses the process**:
```
Token IDs
   ↓
Decoder (handle special encodings)
   ↓
Raw Text
```
## Normalizers
Clean and standardize input text.
### Common normalizers
**Lowercase**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
# Input: "Hello WORLD"
# Output: "hello world"
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
# NFD: Canonical decomposition
tokenizer.normalizer = NFD()
# "é" → "e" + "́" (separate characters)
# NFC: Canonical composition (default)
tokenizer.normalizer = NFC()
# "e" + "́" → "é" (composed)
# NFKD: Compatibility decomposition
tokenizer.normalizer = NFKD()
# "ﬁ" (single ligature character, U+FB01) → "f" + "i"
# NFKC: Compatibility composition
tokenizer.normalizer = NFKC()
# Most aggressive normalization
```
**Strip accents**:
```python
from tokenizers.normalizers import StripAccents
tokenizer.normalizer = StripAccents()
# Input: "café"
# Output: "cafe"
```
**Whitespace handling**:
```python
from tokenizers.normalizers import Strip, StripAccents
# Remove leading/trailing whitespace
tokenizer.normalizer = Strip()
# Input: " hello "
# Output: "hello"
```
**Replace patterns**:
```python
from tokenizers.normalizers import Replace
# Replace newlines with spaces
tokenizer.normalizer = Replace("\\n", " ")
# Input: "hello\\nworld"
# Output: "hello world"
```
### Combining normalizers
```python
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
# BERT-style normalization
tokenizer.normalizer = Sequence([
NFD(), # Unicode decomposition
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Café au Lait"
# After NFD: "Café au Lait" (e + ́)
# After Lowercase: "café au lait"
# After StripAccents: "cafe au lait"
```
### Use case examples
**Case-insensitive model (BERT)**:
```python
from tokenizers.normalizers import BertNormalizer
# All-in-one BERT normalization
tokenizer.normalizer = BertNormalizer(
clean_text=True, # Remove control characters
handle_chinese_chars=True, # Add spaces around Chinese
strip_accents=True, # Remove accents
lowercase=True # Lowercase
)
```
**Case-sensitive model (GPT-2)**:
```python
# Minimal normalization
tokenizer.normalizer = NFC() # Only normalize Unicode
```
**Multilingual (mBERT)**:
```python
# Preserve scripts, normalize form
tokenizer.normalizer = NFKC()
```
## Pre-tokenizers
Split text into word-like units before tokenization.
### Whitespace splitting
```python
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
# Input: "Hello world! How are you?"
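# Output: [("Hello", (0, 5)), ("world", (6, 11)), ("!", (11, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you", (21, 24)), ("?", (24, 25))]
# Note: Whitespace splits on the pattern \w+|[^\w\s]+, so punctuation is separated; use WhitespaceSplit to split on whitespace only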
```
### Punctuation isolation
```python
from tokenizers.pre_tokenizers import Punctuation
tokenizer.pre_tokenizer = Punctuation()
# Input: "Hello, world!"
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Byte-level (GPT-2)
```python
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
# Input: "Hello world"
# Output: Byte-level tokens with Ġ prefix for spaces
# [("ĠHello", ...), ("Ġworld", ...)]
```
**Key feature**: Handles ALL Unicode characters (256 byte combinations)
### Metaspace (SentencePiece)
```python
from tokenizers.pre_tokenizers import Metaspace
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# Input: "Hello world"
# Output: [("▁Hello", ...), ("▁world", ...)]
```
**Used by**: T5, ALBERT (via SentencePiece)
### Digits splitting
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually
tokenizer.pre_tokenizer = Digits(individual_digits=True)
# Input: "Room 123"
# Output: [("Room ", ...), ("1", ...), ("2", ...), ("3", ...)]  (whitespace is kept; Digits only isolates digits)
# Keep digits together
tokenizer.pre_tokenizer = Digits(individual_digits=False)
# Input: "Room 123"
# Output: [("Room ", ...), ("123", ...)]  (whitespace is kept; Digits only isolates digits)
```
### BERT pre-tokenizer
```python
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer.pre_tokenizer = BertPreTokenizer()
# Splits on whitespace and punctuation, preserves CJK
# Input: "Hello, 世界!"
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
```
### Combining pre-tokenizers
```python
from tokenizers.pre_tokenizers import Sequence, WhitespaceSplit, Punctuation
tokenizer.pre_tokenizer = Sequence([
WhitespaceSplit(), # Split on whitespace only (punctuation stays attached)
Punctuation() # Then isolate punctuation
])
# Input: "Hello, world!"
# After WhitespaceSplit: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Pre-tokenizer comparison
| Pre-tokenizer | Use Case | Example |
|-------------------|---------------------------------|--------------------------------------------|
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
## Models
Core tokenization algorithms.
### BPE Model
```python
from tokenizers.models import BPE
model = BPE(
vocab=None, # Or provide pre-built vocab
merges=None, # Or provide merge rules
unk_token="[UNK]", # Unknown token
continuing_subword_prefix="",
end_of_word_suffix="",
fuse_unk=False # Keep unknown tokens separate
)
tokenizer = Tokenizer(model)
```
**Parameters**:
- `vocab`: Dict of token → id
- `merges`: List of merge rules `["a b", "ab c"]`
- `unk_token`: Token for unknown words
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
### WordPiece Model
```python
from tokenizers.models import WordPiece
model = WordPiece(
vocab=None,
unk_token="[UNK]",
max_input_chars_per_word=100, # Max word length
continuing_subword_prefix="##" # BERT-style prefix
)
tokenizer = Tokenizer(model)
```
**Key difference**: Uses `##` prefix for continuing subwords.
### Unigram Model
```python
from tokenizers.models import Unigram
model = Unigram(
vocab=None, # List of (token, score) tuples
unk_id=0, # ID for unknown token
byte_fallback=False # Fall back to bytes if no match
)
tokenizer = Tokenizer(model)
```
**Probabilistic**: Selects tokenization with highest probability.
### WordLevel Model
```python
from tokenizers.models import WordLevel
# Simple word-to-ID mapping (no subwords)
model = WordLevel(
vocab=None,
unk_token="[UNK]"
)
tokenizer = Tokenizer(model)
```
**Warning**: Requires huge vocabulary (one token per word).
## Post-processors
Add special tokens and format output.
### Template processing
**BERT-style** (`[CLS] sentence [SEP]`):
```python
from tokenizers.processors import TemplateProcessing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 101),
("[SEP]", 102),
],
)
# Single sentence
output = tokenizer.encode("Hello world")
# [101, ..., 102] ([CLS] hello world [SEP])
# Sentence pair
output = tokenizer.encode("Hello", "world")
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
```
**GPT-2 style** (`sentence <|endoftext|>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", 50256),
],
)
```
**RoBERTa style** (`<s> sentence </s>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[
("<s>", 0),
("</s>", 2),
],
)
```
**T5 style** (no special tokens):
```python
# T5 doesn't add special tokens via post-processor
tokenizer.post_processor = None
```
### RobertaProcessing
```python
from tokenizers.processors import RobertaProcessing
tokenizer.post_processor = RobertaProcessing(
sep=("</s>", 2),
cls=("<s>", 0),
add_prefix_space=True, # Add space before first token
trim_offsets=True # Trim leading space from offsets
)
```
### ByteLevelProcessing
```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing
tokenizer.post_processor = ByteLevelProcessing(
trim_offsets=True # Remove Ġ from offsets
)
```
## Decoders
Convert token IDs back to text.
### ByteLevel decoder
```python
from tokenizers.decoders import ByteLevel
tokenizer.decoder = ByteLevel()
# Handles byte-level tokens
# ["ĠHello", "Ġworld"] → "Hello world"
```
### WordPiece decoder
```python
from tokenizers.decoders import WordPiece
tokenizer.decoder = WordPiece(prefix="##")
# Removes ## prefix and concatenates
# ["token", "##ization"] → "tokenization"
```
### Metaspace decoder
```python
from tokenizers.decoders import Metaspace
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
# Converts ▁ back to spaces
# ["▁Hello", "▁world"] → "Hello world"
```
### BPEDecoder
```python
from tokenizers.decoders import BPEDecoder
tokenizer.decoder = BPEDecoder(suffix="</w>")
# Removes suffix and concatenates
# ["token", "ization</w>"] → "tokenization"
```
### Sequence decoder
```python
from tokenizers.decoders import Sequence, ByteLevel, Strip
tokenizer.decoder = Sequence([
ByteLevel(), # Decode byte-level first
Strip(' ', 1, 1) # Strip leading/trailing spaces
])
```
## Complete pipeline examples
### BERT tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder
# Model
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization
tokenizer.pre_tokenizer = BertPreTokenizer()
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
# Decoder
tokenizer.decoder = WordPieceDecoder(prefix="##")
# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
# Enable truncation
tokenizer.enable_truncation(max_length=512)
```
### GPT-2 tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing
# Model
tokenizer = Tokenizer(BPE())
# Normalization (minimal)
tokenizer.normalizer = NFC()
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)],
)
# Byte-level decoder
tokenizer.decoder = ByteLevelDecoder()
```
### T5 tokenizer (SentencePiece-style)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder
# Model
tokenizer = Tokenizer(Unigram())
# Normalization
tokenizer.normalizer = NFKC()
# Metaspace pre-tokenization
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# No post-processing (T5 doesn't add CLS/SEP)
tokenizer.post_processor = None
# Metaspace decoder
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
```
## Alignment tracking
Track token positions in original text.
### Basic alignment
```python
text = "Hello, world!"
output = tokenizer.encode(text)
for token, (start, end) in zip(output.tokens, output.offsets):
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
# Output:
# [CLS] → [ 0, 0): ''
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
# [SEP] → [ 0, 0): ''
```
### Word-level alignment
```python
# Get word_ids (which word each token belongs to)
encoding = tokenizer.encode("Hello world")
word_ids = encoding.word_ids
print(word_ids)
# [None, 0, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER)
```python
# Align predictions to words
predictions = ["O", "B-PER", "I-PER", "O", "O"]
word_predictions = {}
for token_idx, word_idx in enumerate(encoding.word_ids):
if word_idx is not None and word_idx not in word_predictions:
word_predictions[word_idx] = predictions[token_idx]
print(word_predictions)
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
```
### Span alignment
```python
# Find token span for character span
text = "Machine learning is awesome"
char_start, char_end = 8, 16 # "learning"
encoding = tokenizer.encode(text)
# Find token span
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end - 1) + 1
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
# Tokens 2:3 = ['learning']
```
**Use case**: Question answering (extract answer span)
## Custom components
### Custom normalizer
```python
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer
class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        # Custom normalization logic
        normalized.lowercase()
        normalized.replace("  ", " ") # Replace double spaces
# Use custom normalizer (Python objects must be wrapped with Normalizer.custom)
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
```
### Custom pre-tokenizer
```python
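from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
class CustomPreTokenizer:
    def split_on_space(self, i: int, normalized: NormalizedString):
        # Must return a list of NormalizedString pieces
        return normalized.split(" ", "removed")
    def pre_tokenize(self, pretok: PreTokenizedString):
        # Custom pre-tokenization logic
        pretok.split(self.split_on_space)
# Python objects must be wrapped with PreTokenizer.custom
tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())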
```
## Troubleshooting
### Issue: Misaligned offsets
**Symptom**: Offsets don't match original text
```python
text = "  hello" # Two leading spaces
offsets = [(0, 5)] # Points at "  hel" instead of "hello" (expected (2, 7))
```
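**Solution**: Check whether a normalizer (such as `Strip`) removes whitespace. If it does, drop it and let the post-processor trim offsets instead.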
```python
# Preserve offsets
tokenizer.normalizer = Sequence([
Strip(), # This changes offsets!
])
# Use trim_offsets in post-processor instead
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```
### Issue: Special tokens not added
**Symptom**: No [CLS] or [SEP] in output
**Solution**: Check post-processor is set
```python
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
```
### Issue: Incorrect decoding
**Symptom**: Decoded text has ## or ▁
**Solution**: Set correct decoder
```python
# For WordPiece
tokenizer.decoder = WordPieceDecoder(prefix="##")
# For SentencePiece
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
```
## Best practices
1. **Match pipeline to model architecture**:
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
- GPT-2 → NFC + ByteLevel + BPE
- T5 → NFKC + Metaspace + Unigram
2. **Test pipeline on sample inputs**:
- Check normalization doesn't over-normalize
- Verify pre-tokenization splits correctly
- Ensure decoding reconstructs text
3. **Preserve alignment for downstream tasks**:
- Use `trim_offsets` instead of stripping in normalizer
- Test `char_to_token()` on sample spans
4. **Document your pipeline**:
- Save complete tokenizer config
- Document special tokens
- Note any custom components
@@ -0,0 +1,565 @@
# Training Custom Tokenizers
Complete guide to training tokenizers from scratch.
## Training workflow
### Step 1: Choose tokenization algorithm
**Decision tree**:
- **GPT-style model** → BPE
- **BERT-style model** → WordPiece
- **Multilingual/No word boundaries** → Unigram
### Step 2: Prepare training data
```python
# Option 1: From files
files = ["train.txt", "validation.txt"]
# Option 2: From Python list
texts = [
"This is the first sentence.",
"This is the second sentence.",
# ... more texts
]
# Option 3: From dataset iterator
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
```
### Step 3: Initialize tokenizer
**BPE example**:
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
trainer = BpeTrainer(
vocab_size=50000,
min_frequency=2,
special_tokens=["<|endoftext|>", "<|padding|>"],
show_progress=True
)
```
**WordPiece example**:
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = BertPreTokenizer()
trainer = WordPieceTrainer(
vocab_size=30522,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##",
show_progress=True
)
```
**Unigram example**:
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
unk_token="<unk>",
show_progress=True
)
```
### Step 4: Train
```python
# From files
tokenizer.train(files=files, trainer=trainer)
# From iterator (recommended for large datasets)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # Optional, for progress bar
)
```
**Training time** (30k vocab on 16-core CPU):
- 10 MB: 15-30 seconds
- 100 MB: 1-3 minutes
- 1 GB: 15-30 minutes
- 10 GB: 2-4 hours
### Step 5: Add post-processing
```python
from tokenizers.processors import TemplateProcessing
# BERT-style
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", tokenizer.token_to_id("[CLS]")),
("[SEP]", tokenizer.token_to_id("[SEP]")),
],
)
# GPT-2 style
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
],
)
```
### Step 6: Save
```python
# Save to JSON
tokenizer.save("my-tokenizer.json")
# Save to directory (for transformers)
tokenizer.save("my-tokenizer-dir/tokenizer.json")
# Convert to transformers format
from transformers import PreTrainedTokenizerFast
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
```
## Trainer configuration
### BpeTrainer parameters
```python
from tokenizers.trainers import BpeTrainer
trainer = BpeTrainer(
vocab_size=30000, # Target vocabulary size
min_frequency=2, # Minimum frequency for merges
special_tokens=["[UNK]"], # Special tokens (added first)
limit_alphabet=1000, # Limit initial alphabet size
initial_alphabet=[], # Pre-defined initial characters
show_progress=True, # Show progress bar
continuing_subword_prefix="", # Prefix for continuing subwords
end_of_word_suffix="" # Suffix for end of words
)
```
**Parameter tuning**:
- **vocab_size**: Start with 30k for English, 50k for multilingual
- **min_frequency**: 2-5 for large corpora, 1 for small
- **limit_alphabet**: Reduce for non-English (CJK languages)
### WordPieceTrainer parameters
```python
from tokenizers.trainers import WordPieceTrainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT uses 30,522
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
limit_alphabet=1000,
continuing_subword_prefix="##", # BERT-style prefix
show_progress=True
)
```
### UnigramTrainer parameters
```python
from tokenizers.trainers import UnigramTrainer
trainer = UnigramTrainer(
vocab_size=8000, # Typically smaller than BPE/WordPiece
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Maximum token length
n_sub_iterations=2, # EM algorithm iterations
shrinking_factor=0.75, # Vocabulary reduction rate
show_progress=True
)
```
## Training from large datasets
### Memory-efficient training
```python
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
# Load dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
# Create iterator (yields batches)
def batch_iterator(batch_size=1000):
batch = []
for sample in dataset:
batch.append(sample["text"])
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
# Initialize tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
# Train (memory efficient - streams data)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer
)
```
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
### Multi-file training
```python
import glob
# Find all training files
files = glob.glob("data/train/*.txt")
print(f"Training on {len(files)} files")
# Train on all files
tokenizer.train(files=files, trainer=trainer)
```
### Parallel training (multi-processing)
```python
from multiprocessing import Pool, cpu_count
import os
def train_shard(shard_files):
"""Train tokenizer on a shard of files."""
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000)
tokenizer.train(files=shard_files, trainer=trainer)
return tokenizer.get_vocab()
# Split files into shards
num_shards = cpu_count()
file_shards = [files[i::num_shards] for i in range(num_shards)]
# Train shards in parallel
with Pool(num_shards) as pool:
vocab_shards = pool.map(train_shard, file_shards)
# Merge vocabularies (custom logic needed)
# This is a simplified example - real implementation would merge intelligently
final_vocab = {}
for vocab in vocab_shards:
final_vocab.update(vocab)
```
## Domain-specific tokenizers
### Code tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import Sequence, NFC
# Code-optimized configuration
tokenizer = Tokenizer(BPE())
# Minimal normalization (preserve case, whitespace)
tokenizer.normalizer = NFC() # Only normalize Unicode
# Byte-level pre-tokenization (handles all characters)
tokenizer.pre_tokenizer = ByteLevel()
# Train on code corpus
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["<|endoftext|>", "<|pad|>"],
min_frequency=2
)
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
```
### Medical/scientific tokenizer
```python
# Preserve case and special characters
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
tokenizer = Tokenizer(BPE())
# Minimal normalization
tokenizer.normalizer = NFKC()
# Preserve medical terms
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation(behavior="isolated") # Keep punctuation separate
])
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
min_frequency=3 # Higher threshold for rare medical terms
)
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
```
### Multilingual tokenizer
```python
# Handle multiple scripts
from tokenizers.normalizers import NFKC, Lowercase, Sequence
tokenizer = Tokenizer(BPE())
# Normalize but don't lowercase (preserves script differences)
tokenizer.normalizer = NFKC()
# Byte-level handles all Unicode
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=100000, # Larger vocab for multiple languages
special_tokens=["<unk>", "<s>", "</s>"],
limit_alphabet=None # No limit (handles all scripts)
)
# Train on multilingual corpus
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
```
## Vocabulary size selection
### Guidelines by task
| Task | Recommended Vocab Size | Rationale |
|-----------------------|------------------------|-----------|
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
| Code | 30,000 - 50,000 | Similar to English |
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
### Vocabulary size impact
**Small vocab (10k)**:
- Pros: Faster training, smaller model, less memory
- Cons: More tokens per sentence, worse OOV handling
**Medium vocab (30k-50k)**:
- Pros: Good balance, standard choice
- Cons: None (recommended default)
**Large vocab (100k+)**:
- Pros: Fewer tokens per sentence, better OOV
- Cons: Slower training, larger embedding table
### Empirical testing
```python
# Train multiple tokenizers with different vocab sizes
vocab_sizes = [10000, 30000, 50000, 100000]
for vocab_size in vocab_sizes:
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=vocab_size)
tokenizer.train(files=["sample.txt"], trainer=trainer)
# Evaluate on test set
test_text = "Test sentence for evaluation..."
tokens = tokenizer.encode(test_text).ids
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
# Example output:
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
```
## Testing tokenizer quality
### Coverage test
```python
# Test on held-out data
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
total_tokens = 0
unk_tokens = 0
unk_id = tokenizer.token_to_id("[UNK]")
for text in test_corpus["text"]:
if text.strip():
encoding = tokenizer.encode(text)
total_tokens += len(encoding.ids)
unk_tokens += encoding.ids.count(unk_id)
unk_rate = unk_tokens / total_tokens
print(f"Unknown token rate: {unk_rate:.2%}")
# Good quality: <1% unknown tokens
# Acceptable: 1-5%
# Poor: >5%
```
### Compression test
```python
# Measure tokenization efficiency
import numpy as np
token_lengths = []
for text in test_corpus["text"][:1000]:
if text.strip():
encoding = tokenizer.encode(text)
chars_per_token = len(text) / len(encoding.ids)
token_lengths.append(chars_per_token)
avg_chars_per_token = np.mean(token_lengths)
print(f"Average characters per token: {avg_chars_per_token:.2f}")
# Good: 4-6 chars/token (English)
# Acceptable: 3-4 chars/token
# Poor: <3 chars/token (under-compression)
```
### Semantic test
```python
# Manually inspect tokenization of common words/phrases
test_phrases = [
"tokenization",
"machine learning",
"artificial intelligence",
"preprocessing",
"hello world"
]
for phrase in test_phrases:
tokens = tokenizer.encode(phrase).tokens
print(f"{phrase:25s} → {tokens}")
# Good tokenization:
# tokenization → ['token', 'ization']
# machine learning → ['machine', 'learning']
# artificial intelligence → ['artificial', 'intelligence']
```
## Troubleshooting
### Issue: Training too slow
**Solutions**:
1. Reduce vocabulary size
2. Increase `min_frequency`
3. Use `limit_alphabet` to reduce initial alphabet
4. Train on subset first
```python
# Fast training configuration
trainer = BpeTrainer(
vocab_size=20000, # Smaller vocab
min_frequency=5, # Higher threshold
limit_alphabet=500, # Limit alphabet
show_progress=True
)
```
### Issue: High unknown token rate
**Solutions**:
1. Increase vocabulary size
2. Decrease `min_frequency`
3. Check normalization (might be too aggressive)
```python
# Better coverage configuration
trainer = BpeTrainer(
vocab_size=50000, # Larger vocab
min_frequency=1, # Lower threshold
)
```
### Issue: Poor quality tokenization
**Solutions**:
1. Verify normalization matches your use case
2. Check pre-tokenization splits correctly
3. Ensure training data is representative
4. Try different algorithm (BPE vs WordPiece vs Unigram)
```python
# Debug tokenization pipeline
text = "Sample text to debug"
# Check normalization
normalized = tokenizer.normalizer.normalize_str(text)
print(f"Normalized: {normalized}")
# Check pre-tokenization
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
print(f"Pre-tokens: {pre_tokens}")
# Check final tokenization
tokens = tokenizer.encode(text).tokens
print(f"Tokens: {tokens}")
```
## Best practices
1. **Use representative training data** - Match your target domain
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
3. **Test on held-out data** - Measure unknown token rate
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
5. **Save tokenizer with model** - Ensure reproducibility
6. **Version your tokenizers** - Track changes for reproducibility
7. **Document special tokens** - Critical for model training
@@ -0,0 +1,235 @@
---
name: sentencepiece
description: Language-independent tokenizer treating text as raw Unicode. Supports BPE and Unigram algorithms. Fast (50k sentences/sec), lightweight (6MB memory), deterministic vocabulary. Used by T5, ALBERT, XLNet, mBART. Train on raw text without pre-tokenization. Use when you need multilingual support, CJK languages, or reproducible tokenization.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Tokenization, SentencePiece, Language-Independent, BPE, Unigram, Multilingual, CJK Languages, Unicode, Deterministic, Google]
dependencies: [sentencepiece, transformers]
---
# SentencePiece - Language-Independent Tokenization
Unsupervised tokenizer that works on raw text without language-specific preprocessing.
## When to use SentencePiece
**Use SentencePiece when:**
- Building multilingual models (no language-specific rules)
- Working with CJK languages (Chinese, Japanese, Korean)
- Need reproducible tokenization (deterministic vocabulary)
- Want to train on raw text (no pre-tokenization needed)
- Require lightweight deployment (6MB memory, 50k sentences/sec)
**Performance**:
- **Speed**: 50,000 sentences/sec
- **Memory**: ~6MB for loaded model
- **Languages**: All (language-independent)
**Use alternatives instead**:
- **HuggingFace Tokenizers**: Faster training, more flexibility
- **tiktoken**: OpenAI models (GPT-3.5/4)
- **BERT WordPiece**: English-centric tasks
## Quick start
### Installation
```bash
# Python
pip install sentencepiece
# C++ (requires CMake)
git clone https://github.com/google/sentencepiece.git
cd sentencepiece
mkdir build && cd build
cmake .. && make -j $(nproc)
sudo make install
```
### Train model
```bash
# Command-line (BPE with 8000 vocab)
spm_train --input=data.txt --model_prefix=m --vocab_size=8000 --model_type=bpe
```
```python
# Python API
import sentencepiece as spm
spm.SentencePieceTrainer.train(
input='data.txt',
model_prefix='m',
vocab_size=8000,
model_type='bpe'
)
```
**Training time**: ~1-2 minutes for 100MB corpus
### Encode and decode
```python
import sentencepiece as spm
# Load model
sp = spm.SentencePieceProcessor(model_file='m.model')
# Encode to pieces
pieces = sp.encode('This is a test', out_type=str)
print(pieces) # ['▁This', '▁is', '▁a', '▁test']
# Encode to IDs
ids = sp.encode('This is a test', out_type=int)
print(ids) # [284, 47, 11, 1243]
# Decode
text = sp.decode(ids)
print(text) # "This is a test"
```
## Language-independent design
### Whitespace as symbol (▁)
```python
text = "Hello world"
pieces = sp.encode(text, out_type=str)
print(pieces) # ['▁Hello', '▁world']
# Decode preserves spaces
decoded = sp.decode_pieces(pieces)
print(decoded) # "Hello world"
```
**Key principle**: Treat text as raw Unicode, whitespace = ▁ (meta symbol)
## Tokenization algorithms
### BPE (Byte-Pair Encoding)
```python
spm.SentencePieceTrainer.train(
input='data.txt',
model_prefix='bpe_model',
vocab_size=16000,
model_type='bpe'
)
```
**Used by**: mBART
### Unigram (default)
```python
spm.SentencePieceTrainer.train(
input='data.txt',
model_prefix='unigram_model',
vocab_size=8000,
model_type='unigram'
)
```
**Used by**: T5, ALBERT, XLNet
## Training configuration
### Essential parameters
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_prefix='m',
vocab_size=32000,
model_type='unigram',
character_coverage=0.9995, # 1.0 for CJK
user_defined_symbols=['[SEP]', '[CLS]'],
unk_piece='<unk>',
num_threads=16
)
```
### Character coverage
| Language Type | Coverage | Rationale |
|---------------|----------|-----------|
| English | 0.9995 | Most common chars |
| CJK (Chinese) | 1.0 | All characters needed |
| Multilingual | 0.9995 | Balance |
## Encoding options
### Subword regularization
```python
# Sample different tokenizations
for _ in range(3):
pieces = sp.encode('tokenization', out_type=str, enable_sampling=True, alpha=0.1)
print(pieces)
# Output (different each time):
# ['▁token', 'ization']
# ['▁tok', 'en', 'ization']
```
**Use case**: Data augmentation for robustness.
## Common patterns
### T5-style training
```python
spm.SentencePieceTrainer.train(
input='c4_corpus.txt',
model_prefix='t5',
vocab_size=32000,
model_type='unigram',
user_defined_symbols=[f'<extra_id_{i}>' for i in range(100)],
unk_id=2,
eos_id=1,
pad_id=0,
bos_id=-1 # Disable BOS: its default ID (1) would collide with eos_id=1
)
```
### Integration with transformers
```python
from transformers import T5Tokenizer
# T5 uses SentencePiece internally
tokenizer = T5Tokenizer.from_pretrained('t5-base')
inputs = tokenizer('translate English to French: Hello', return_tensors='pt')
```
## Performance benchmarks
### Training speed
| Corpus | BPE (16k) | Unigram (8k) |
|--------|-----------|--------------|
| 100 MB | 1-2 min | 3-4 min |
| 1 GB | 10-15 min | 30-40 min |
### Tokenization speed
- **SentencePiece**: 50,000 sentences/sec
- **HF Tokenizers**: 200,000 sentences/sec (4× faster)
## Supported models
**T5 family**: `t5-base`, `t5-large` (32k vocab, Unigram)
**ALBERT**: `albert-base-v2` (30k vocab, Unigram)
**XLNet**: `xlnet-base-cased` (32k vocab, Unigram)
**mBART**: `facebook/mbart-large-50` (250k vocab, BPE)
## References
- **[Training Guide](references/training.md)** - Detailed options, corpus preparation
- **[Algorithms](references/algorithms.md)** - BPE vs Unigram, subword regularization
## Resources
- **GitHub**: https://github.com/google/sentencepiece ⭐ 10,000+
- **Paper**: https://arxiv.org/abs/1808.06226 (EMNLP 2018)
- **Version**: 0.2.0+
@@ -0,0 +1,200 @@
# Tokenization Algorithms
BPE vs Unigram comparison and subword regularization.
## BPE (Byte-Pair Encoding)
### Algorithm
1. Initialize vocabulary with characters
2. Count frequency of adjacent token pairs
3. Merge most frequent pair
4. Repeat until vocabulary size reached
### Example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
- Most frequent pair: 'e' + 's' (9 times)
- Merge → 'es'
- Vocabulary: [chars] + ['es']
**Iteration 2**:
- Most frequent: 'es' + 't' (9 times)
- Merge → 'est'
- Vocabulary: [chars] + ['es', 'est']
**Result**: `newest` → `new|est`, `widest` → `wid|est`
### Implementation
```python
import sentencepiece as spm
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_type='bpe',
vocab_size=16000
)
```
### Advantages
- Simple algorithm
- Fast training
- Good compression ratio
### Disadvantages
- Deterministic (no sampling)
- May split common words unexpectedly
## Unigram
### Algorithm
1. Start with large vocabulary (all substrings)
2. Compute probability of each token
3. Remove tokens with minimal loss impact
4. Repeat until vocabulary size reached
### Probabilistic tokenization
Given vocabulary with probabilities:
```
P('low') = 0.02
P('est') = 0.03
P('l') = 0.01
P('o') = 0.015
...
```
Tokenize "lowest":
```
Option 1: ['low', 'est']
P = 0.02 × 0.03 = 0.0006 ← highest
Option 2: ['l', 'o', 'w', 'est']
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
Choose option 1 (highest probability)
```
### Implementation
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_type='unigram',
vocab_size=8000
)
```
### Advantages
- Probabilistic (can sample)
- Better for morphologically rich languages
- Supports subword regularization
### Disadvantages
- Slower training
- More complex algorithm
## Comparison
| Feature | BPE | Unigram |
|---------|-----|---------|
| Training speed | Fast | Slow |
| Tokenization | Deterministic | Probabilistic |
| Sampling | No | Yes |
| Typical vocab size | 16k-32k | 8k-32k |
| Used by | mBART | T5, ALBERT, XLNet |
## Subword regularization
Sample different tokenizations during training for robustness.
### Enable sampling
```python
sp = spm.SentencePieceProcessor(model_file='m.model')
# Sample different tokenizations
for _ in range(5):
pieces = sp.encode('tokenization', out_type=str, enable_sampling=True, alpha=0.1)
print(pieces)
# Output (different each time):
# ['▁token', 'ization']
# ['▁tok', 'en', 'ization']
# ['▁token', 'iz', 'ation']
# ['▁to', 'ken', 'ization']
# ['▁token', 'ization']
```
### Parameters
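- `alpha`: Sampling parameter; its meaning depends on the model type
- Unigram: smoothing (inverse-temperature) parameter — smaller values flatten the distribution (0.0 = uniform over candidate segmentations, i.e. maximum variation); larger values concentrate on the most probable segmentation
- BPE: merge-dropout probability (BPE-dropout) — 0.0 = deterministic (no sampling), 0.1 = slight variation, 0.5 = high variation, 1.0 = maximum variation (no merges applied)
- Sampling only takes effect when `enable_sampling=True` is passed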
### Benefits
1. **Robustness**: Model learns multiple tokenizations
2. **Data augmentation**: More diverse training data
3. **Better generalization**: Less overfitting to specific tokenization
### Use case
```python
# Training loop with regularization
for batch in dataloader:
# Sample different tokenizations each epoch
tokens = sp.encode(batch['text'], enable_sampling=True, alpha=0.1)
# Train model...
```
**Used by**: mT5, XLM-RoBERTa
## NBest encoding
Get multiple tokenization candidates with scores.
```python
sp = spm.SentencePieceProcessor(model_file='m.model')
# Get top-5 tokenizations
nbest = sp.nbest_encode('tokenization', nbest_size=5, out_type=str)
for pieces, score in nbest:
print(f"{pieces} (log prob: {score:.4f})")
# Output:
# ['▁token', 'ization'] (log prob: -2.34)
# ['▁tok', 'en', 'ization'] (log prob: -2.41)
# ['▁token', 'iz', 'ation'] (log prob: -2.57)
```
### Use cases
1. **Ensemble tokenization**: Average over multiple tokenizations
2. **Uncertainty estimation**: Check variance in scores
3. **Debugging**: Understand tokenizer behavior
## Best practices
1. **Use Unigram for multilingual** - Better for diverse languages
2. **Use BPE for speed** - Faster training and inference
3. **Enable subword regularization** - Improves model robustness
4. **Set alpha=0.1 for slight variation** - Good balance
5. **Use deterministic mode for inference** - Consistent results
@@ -0,0 +1,304 @@
# SentencePiece Training Guide
Complete guide to training SentencePiece models.
## Training workflow
### Step 1: Prepare corpus
```bash
# Plain text file, one sentence per line (recommended)
cat corpus.txt
# Hello world
# This is a test
# SentencePiece is language-independent
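# Raw, untokenized text also works, but SentencePiece treats each line as one sentence
# (lines longer than --max_sentence_length, 4192 bytes by default, are skipped)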
```
### Step 2: Train model
**Command-line**:
```bash
spm_train \
--input=corpus.txt \
--model_prefix=m \
--vocab_size=8000 \
--model_type=unigram \
--character_coverage=0.9995
```
**Python API**:
```python
import sentencepiece as spm
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_prefix='m',
vocab_size=8000,
model_type='unigram'
)
```
**Output**: `m.model` (binary), `m.vocab` (text vocabulary)
### Step 3: Load and use
```python
sp = spm.SentencePieceProcessor(model_file='m.model')
pieces = sp.encode('Test sentence', out_type=str)
```
## Training parameters
### Core parameters
```python
spm.SentencePieceTrainer.train(
# Required
input='corpus.txt', # Input corpus
model_prefix='output', # Output prefix
vocab_size=8000, # Target vocabulary size
# Algorithm
model_type='unigram', # 'unigram', 'bpe', 'char', 'word'
# Coverage
character_coverage=0.9995, # 0.9995 for most, 1.0 for CJK
# Normalization
normalization_rule_name='nmt_nfkc', # 'nmt_nfkc', 'nfkc', 'identity'
# Performance
num_threads=16, # Training threads
input_sentence_size=10000000 # Max sentences to load
)
```
### Special tokens
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_prefix='m',
vocab_size=32000,
# Control symbols (special tokens for model control)
control_symbols=['<s>', '</s>', '<pad>'],
# User-defined symbols (never split)
user_defined_symbols=['[MASK]', '[SEP]', '[CLS]'],
# Special token pieces
unk_piece='<unk>',
bos_piece='<s>',
eos_piece='</s>',
pad_piece='<pad>',
# Special token IDs
unk_id=0,
bos_id=1,
eos_id=2,
pad_id=3
)
```
### Advanced options
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_prefix='m',
vocab_size=32000,
# Byte fallback (handle unknown chars)
byte_fallback=True,
# Digit handling
split_digits=True, # Split digits individually
# Script splitting
split_by_unicode_script=True, # Split by Unicode script
split_by_whitespace=True, # Split by whitespace
# Length constraints
max_sentencepiece_length=16, # Max token length
# Rare word handling
min_frequency=2, # Min frequency for token
# Training size
input_sentence_size=10000000, # Max sentences
shuffle_input_sentence=True, # Shuffle training data
# Seed
seed_sentencepiece_size=1000000 # Seed vocab size
)
```
## Training from Python iterator
```python
import sentencepiece as spm
from datasets import load_dataset
# Load dataset
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
# Create iterator
def corpus_iterator():
for example in dataset:
if example['text'].strip():
yield example['text']
# Train from iterator
spm.SentencePieceTrainer.train(
sentence_iterator=corpus_iterator(),
model_prefix='wiki',
vocab_size=32000,
model_type='unigram'
)
```
## Model types
### BPE
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_type='bpe',
vocab_size=16000
)
```
**Training time**: ~10-15 min for 1GB corpus
### Unigram (recommended)
```python
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_type='unigram',
vocab_size=8000
)
```
**Training time**: ~30-40 min for 1GB corpus
## Character coverage
### English/European (0.9995)
```python
spm.SentencePieceTrainer.train(
input='en_corpus.txt',
character_coverage=0.9995 # Cover 99.95% of chars
)
```
Covers: a-z, A-Z, punctuation, common accents
### CJK (1.0)
```python
spm.SentencePieceTrainer.train(
input='zh_corpus.txt',
character_coverage=1.0 # Cover ALL characters
)
```
Required for: Chinese, Japanese, Korean
### Multilingual (0.9995-1.0)
```python
spm.SentencePieceTrainer.train(
input='multilingual_corpus.txt',
character_coverage=0.9995 # Balance coverage/size
)
```
## Vocabulary size selection
| Task | Vocab Size | Rationale |
|------|------------|-----------|
| English monolingual | 16k-32k | Standard |
| Multilingual | 32k-250k | More languages |
| CJK | 32k-100k | More characters |
| Code | 16k-32k | Similar to English |
## Normalization rules
### nmt_nfkc (recommended)
```python
normalization_rule_name='nmt_nfkc'
```
- NFKC Unicode normalization
- Whitespace handling
- **Recommended for most tasks**
### identity (no normalization)
```python
normalization_rule_name='identity'
```
- Preserves input exactly
- Use for code, case-sensitive tasks
### nfkc (standard Unicode)
```python
normalization_rule_name='nfkc'
```
- Standard Unicode normalization
- Less aggressive than nmt_nfkc
## Performance optimization
### Multi-threading
```python
spm.SentencePieceTrainer.train(
input='large_corpus.txt',
num_threads=32 # Use all cores
)
```
**Speedup**: ~4-8× with 16+ cores
### Sampling input
```python
spm.SentencePieceTrainer.train(
input='huge_corpus.txt',
input_sentence_size=10000000, # Sample 10M sentences
shuffle_input_sentence=True
)
```
**For very large corpora** (>10GB)
### Extremely large corpus
```python
spm.SentencePieceTrainer.train(
input='massive_corpus.txt',
train_extremely_large_corpus=True, # Enable for >10GB
input_sentence_size=100000000
)
```
## Best practices
1. **Use Unigram for most tasks** - Better for multilingual
2. **Set character_coverage=1.0 for CJK** - Required for full coverage
3. **Use nmt_nfkc normalization** - Works well for most cases
4. **Add user_defined_symbols for special tokens** - BERT-style tokens
5. **Enable byte_fallback for robustness** - Handles emojis/rare chars
6. **Start with vocab_size=32000** - Good default for most tasks
7. **Use multi-threading** - Speeds up training significantly
@@ -0,0 +1,158 @@
---
name: axolotl
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
---
# Axolotl Skill
Comprehensive assistance with axolotl development, generated from official documentation.
## When to Use This Skill
This skill should be triggered when:
- Working with axolotl
- Asking about axolotl features or APIs
- Implementing axolotl solutions
- Debugging axolotl code
- Learning axolotl best practices
## Quick Reference
### Common Patterns
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
```
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
```
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
```
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
```
context_parallel_size
```
**Pattern 4:** Context parallelism reduces the number of distinct batches processed per step. For example: with 8 GPUs and no sequence parallelism, 8 different batches are processed per step; with 8 GPUs and context_parallel_size=4, only 2 different batches are processed per step (each split across 4 GPUs). If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4.
```
context_parallel_size=4
```
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which reduces disk space usage by approximately 40%, maintains compatibility with vLLM for accelerated inference, and maintains compatibility with llmcompressor for further optimization (for example, quantization).
```
save_compressed: true
```
**Pattern 6:** Note: It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it's installed in a package in your Python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
```
integrations
```
**Pattern 7:** Handle both single-example and batched data. - single example: sample[input_ids] is a list[int] - batched data: sample[input_ids] is a list[list[int]]
```
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
```
### Example Code Patterns
**Example 1** (python):
```python
cli.cloud.modal_.ModalCloud(config, app=None)
```
**Example 2** (python):
```python
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
```
**Example 3** (python):
```python
core.trainers.base.AxolotlTrainer(
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
)
```
**Example 4** (python):
```python
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
```
**Example 5** (python):
```python
prompt_strategies.input_output.RawInputOutputPrompter()
```
## Reference Files
This skill includes comprehensive documentation in `references/`:
- **api.md** - Api documentation
- **dataset-formats.md** - Dataset-Formats documentation
- **other.md** - Other documentation
Use `view` to read specific reference files when detailed information is needed.
## Working with This Skill
### For Beginners
This skill has no dedicated getting-started or tutorials file; start with `other.md` (general documentation) and `dataset-formats.md` for foundational concepts.
### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.
### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.
## Resources
### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation
### scripts/
Add helper scripts here for common automation tasks.
### assets/
Add templates, boilerplate, or example projects here.
## Notes
- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs
## Updating
To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,15 @@
# Axolotl Documentation Index
## Categories
### Api
**File:** `api.md`
**Pages:** 150
### Dataset-Formats
**File:** `dataset-formats.md`
**Pages:** 9
### Other
**File:** `other.md`
**Pages:** 26
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,80 @@
---
name: llama-factory
description: Expert guidance for fine-tuning LLMs with LLaMA-Factory - WebUI no-code, 100+ models, 2/3/4/5/6/8-bit QLoRA, multimodal support
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Fine-Tuning, LLaMA Factory, LLM, WebUI, No-Code, QLoRA, LoRA, Multimodal, HuggingFace, Llama, Qwen, Gemma]
dependencies: [llmtuner, torch, transformers, datasets, peft, accelerate, gradio]
---
# Llama-Factory Skill
Comprehensive assistance with llama-factory development, generated from official documentation.
## When to Use This Skill
This skill should be triggered when:
- Working with llama-factory
- Asking about llama-factory features or APIs
- Implementing llama-factory solutions
- Debugging llama-factory code
- Learning llama-factory best practices
## Quick Reference
### Common Patterns
*Quick reference patterns will be added as you use the skill.*
## Reference Files
This skill includes comprehensive documentation in `references/`:
- **_images.md** - Images documentation
- **advanced.md** - Advanced documentation
- **getting_started.md** - Getting Started documentation
- **other.md** - Other documentation
Use `view` to read specific reference files when detailed information is needed.
## Working with This Skill
### For Beginners
Start with the getting_started or tutorials reference files for foundational concepts.
### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.
### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.
## Resources
### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation
### scripts/
Add helper scripts here for common automation tasks.
### assets/
Add templates, boilerplate, or example projects here.
## Notes
- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs
## Updating
To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information
@@ -0,0 +1,23 @@
# Llama-Factory - Images
**Pages:** 3
---
##
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/logo.png
---
##
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/quantization_0.png
---
##
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/webui_0.png
---
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,349 @@
# Llama-Factory - Getting Started
**Pages:** 7
---
## Installation¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/installation.html
**Contents:**
- Installation¶
- Linux¶
- CUDA 安装¶
- Windows¶
- CUDA 安装¶
- LLaMA-Factory 安装¶
- LLaMA-Factory 校验¶
- LLaMA-Factory 高级选项¶
- Windows¶
- QLoRA¶
CUDA 是由 NVIDIA 创建的一个并行计算平台和编程模型,它让开发者可以使用 NVIDIA 的 GPU 进行高性能的并行计算。
首先,在 https://developer.nvidia.com/cuda-gpus 查看您的 GPU 是否支持CUDA
保证当前 Linux 版本支持CUDA. 在命令行中输入 uname -m && cat /etc/*release,应当看到类似的输出
检查是否安装了 gcc . 在命令行中输入 gcc --version ,应当看到类似的输出
在以下网址下载所需的 CUDA,这里推荐12.2版本。 https://developer.nvidia.com/cuda-toolkit-archive 注意需要根据上述输出选择正确版本
如果您之前安装过 CUDA(例如为12.1版本),需要先使用 sudo /usr/local/cuda-12.1/bin/cuda-uninstaller 卸载。如果该命令无法运行,可以直接:
卸载完成后运行以下命令并根据提示继续安装:
注意:在确定 CUDA 自带驱动版本与 GPU 是否兼容之前,建议取消 Driver 的安装。
完成后输入 nvcc -V 检查是否出现对应的版本号,若出现则安装完成。
打开 设置 ,在 关于 中找到 Windows 规格 保证系统版本在以下列表中:
Microsoft Windows 11 21H2
Microsoft Windows 11 22H2-SV2
Microsoft Windows 11 23H2
Microsoft Windows 10 21H2
Microsoft Windows 10 22H2
Microsoft Windows Server 2022
打开 cmd 输入 nvcc -V ,若出现类似内容则安装成功。
否则,检查系统环境变量,保证 CUDA 被正确导入。
在安装 LLaMA-Factory 之前,请确保您安装了下列依赖:
运行以下指令以安装 LLaMA-Factory 及其依赖:
如果出现环境冲突,请尝试使用 pip install --no-deps -e . 解决
完成安装后,可以通过使用 llamafactory-cli version 来快速校验安装是否成功
如果您能成功看到类似下面的界面,就说明安装成功了。
如果您想在 Windows 上启用量化 LoRA(QLoRA),请根据您的 CUDA 版本选择适当的 bitsandbytes 发行版本。
如果您要在 Windows 平台上启用 FlashAttention-2,请根据您的 CUDA 版本选择适当的 flash-attention 发行版本。
开源深度学习框架 PyTorch,广泛用于机器学习和人工智能研究中。
提供了加载 Qwen v1 模型所需的包。
魔搭社区,提供了预训练模型和数据集的下载途径。
开源训练跟踪工具 SwanLab,用于记录与可视化训练过程
用于 LLaMA Factory 开发维护。
---
## WebUI¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/webui.html
**Contents:**
- WebUI¶
- 训练¶
- 评估预测与对话¶
- 导出¶
LLaMA-Factory 支持通过 WebUI 零代码微调大语言模型。 在完成 安装 后,您可以通过以下指令进入 WebUI:
WebUI 主要分为四个界面:训练、评估与预测、对话、导出。
随后,您可以点击 开始 按钮开始训练模型。
关于断点重连:适配器断点保存于 output_dir 目录下,请指定 适配器路径 以加载断点继续训练。
如果您需要使用自定义数据集,请在 data/dataset_info.json 中添加自定义数据集描述并确保 数据集格式 正确,否则可能会导致训练失败。
模型训练完毕后,您可以通过在评估与预测界面通过指定 模型 及 适配器 的路径在指定数据集上进行评估。
您也可以通过在对话界面指定 模型、 适配器 及 推理引擎 后输入对话内容与模型进行对话观察效果。
如果您对模型效果满意并需要导出模型,您可以在导出界面通过指定 模型、 适配器、 分块大小、 导出量化等级及校准数据集、 导出设备、 导出目录 等参数后点击 导出 按钮导出模型。
---
## Merge¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/merge_lora.html
**Contents:**
- Merge¶
- 合并¶
- 量化¶
当我们基于预训练模型训练好 LoRA 适配器后,我们不希望在每次推理的时候分别加载预训练模型和 LoRA 适配器,因此我们需要将预训练模型和 LoRA 适配器合并导出成一个模型,并根据需要选择是否量化。根据是否量化以及量化算法的不同,导出的配置文件也有所区别。
您可以通过 llamafactory-cli export merge_config.yaml 指令来合并模型。其中 merge_config.yaml 需要您根据不同情况进行配置。
examples/merge_lora/llama3_lora_sft.yaml 提供了合并时的配置示例。
模型 model_name_or_path 需要存在且与 template 相对应。 adapter_name_or_path 需要与微调中的适配器输出路径 output_dir 相对应。
合并 LoRA 适配器时,不要使用量化模型或指定量化位数。您可以使用本地或下载的未量化的预训练模型进行合并。
在完成模型合并并获得完整模型后,为了优化部署效果,人们通常会基于显存占用、使用成本和推理速度等因素,选择通过量化技术对模型进行压缩,从而实现更高效的部署。
量化(Quantization)通过数据精度压缩有效地减少了显存使用并加速推理。LLaMA-Factory 支持多种量化方法,包括:
GPTQ 等后训练量化方法(Post Training Quantization)是一种在训练后对预训练模型进行量化的方法。我们通过量化技术将高精度表示的预训练模型转换为低精度的模型,从而在避免过多损失模型性能的情况下减少显存占用并加速推理,我们希望低精度数据类型在有限的表示范围内尽可能地接近高精度数据类型的表示,因此我们需要指定量化位数 export_quantization_bit 以及校准数据集 export_quantization_dataset。
model_name_or_path: 预训练模型的名称或路径
export_quantization_bit: 量化位数
export_quantization_dataset: 量化校准数据集
export_size: 最大导出模型文件大小
export_legacy_format: 是否使用旧格式导出
QLoRA 是一种在 4-bit 量化模型基础上使用 LoRA 方法进行训练的技术。它在极大地保持了模型性能的同时大幅减少了显存占用和推理时间。
不要使用量化模型或设置量化位数 quantization_bit
---
## Inference¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/inference.html
**Contents:**
- Inference¶
- 原始模型推理配置¶
- 微调模型推理配置¶
- 多模态模型¶
- 批量推理¶
- 数据集¶
- api¶
LLaMA-Factory 支持多种推理方式。
您可以使用 llamafactory-cli chat inference_config.yaml 或 llamafactory-cli webchat inference_config.yaml 进行推理与模型对话。对话时配置文件只需指定原始模型 model_name_or_path 和 template ,并根据是否是微调模型指定 adapter_name_or_path 和 finetuning_type。
如果您希望向模型输入大量数据集并保存推理结果,您可以启动 vllm 推理引擎对大量数据集进行快速的批量推理。您也可以通过 部署 api 服务的形式通过 api 调用来进行批量推理。
默认情况下,模型推理将使用 Huggingface 引擎。 您也可以指定 infer_backend: vllm 以使用 vllm 推理引擎以获得更快的推理速度。
使用任何方式推理时,模型 model_name_or_path 需要存在且与 template 相对应。
对于原始模型推理, inference_config.yaml 中 只需指定原始模型 model_name_or_path 和 template 即可。
对于微调模型推理,除原始模型和模板外,还需要指定适配器路径 adapter_name_or_path 和微调类型 finetuning_type。
对于多模态模型,您可以运行以下指令进行推理。
examples/inference/llava1_5.yaml 的配置示例如下:
您可以通过以下指令启动 vllm 推理引擎并使用数据集进行批量推理:
如果您需要使用 api 进行批量推理,您只需指定模型、适配器(可选)、模板、微调方式等信息。
下面是一个启动并调用 api 服务的示例:
您可以使用 API_PORT=8000 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml 启动 api 服务并运行以下示例程序进行调用:
---
## Eval¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/eval.html
**Contents:**
- Eval¶
- 通用能力评估¶
- NLG 评估¶
- 评估相关参数¶
在完成模型训练后,您可以通过 llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml 来评估模型效果。
配置示例文件 examples/train_lora/llama3_lora_eval.yaml 具体如下:
此外,您还可以通过 llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml 来获得模型的 BLEU 和 ROUGE 分数以评价模型生成质量。
配置示例文件 examples/extras/nlg_eval/llama3_lora_predict.yaml 具体如下:
同样,您也可以通过在指令 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo 中指定模型、数据集以使用 vllm 推理框架以取得更快的推理速度。
评估任务的名称,可选项有 mmlu_test, ceval_validation, cmmlu_test
包含评估数据集的文件夹路径,默认值为 evaluation。
用于数据加载器的随机种子,默认值为 42。
评估使用的语言,可选值为 en、 zh。默认值为 en。
few-shot 的示例数量,默认值为 5。
保存评估结果的路径,默认值为 None。 如果该路径已经存在则会抛出错误。
评估数据集的下载模式,默认值为 DownloadMode.REUSE_DATASET_IF_EXISTS。如果数据集已经存在则重复使用,否则下载。
---
## Data Preparation¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/data_preparation.html
**Contents:**
- Data Preparation¶
- Alpaca¶
- 指令监督微调数据集¶
- 预训练数据集¶
- 偏好数据集¶
- KTO 数据集¶
- 多模态数据集¶
- 图像数据集¶
- 视频数据集¶
- 音频数据集¶
dataset_info.json 包含了所有经过预处理的 本地数据集 以及 在线数据集。如果您希望使用自定义数据集,请 务必 在 dataset_info.json 文件中添加对数据集及其内容的定义。
目前我们支持 Alpaca 格式和 ShareGPT 格式的数据集。
指令监督微调(Instruct Tuning)通过让模型学习详细的指令以及对应的回答来优化模型在特定指令下的表现。
instruction 列对应的内容为人类指令, input 列对应的内容为人类输入, output 列对应的内容为模型回答。下面是一个例子
在进行指令监督微调时, instruction 列对应的内容会与 input 列对应的内容拼接后作为最终的人类输入,即人类输入为 instruction\ninput。而 output 列对应的内容为模型回答。 在上面的例子中,人类的最终输入是:
如果指定, system 列对应的内容将被作为系统提示词。
history 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容也会被用于模型学习。
下面提供一个 alpaca 格式 多轮 对话的例子,对于单轮对话只需省略 history 列即可。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
大语言模型通过学习未被标记的文本进行预训练,从而学习语言的表征。通常,预训练数据集从互联网上获得,因为互联网上提供了大量的不同领域的文本信息,有助于提升模型的泛化能力。 预训练数据集文本描述格式如下:
在预训练时,只有 text 列中的 内容 (即document)会用于模型学习。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。对于系统指令和人类输入,偏好数据集给出了一个更优的回答和一个更差的回答。
一些研究 表明通过让模型学习“什么更好”可以使得模型更加迎合人类的需求。 甚至可以使得参数相对较少的模型的表现优于参数更多的模型。
偏好数据集需要在 chosen 列中提供更优的回答,并在 rejected 列中提供更差的回答,在一轮问答中其格式如下:
对于上述格式的数据,dataset_info.json 中的 数据集描述 应为:
KTO数据集与偏好数据集类似,但不同于给出一个更优的回答和一个更差的回答,KTO数据集对每一轮问答只给出一个 true/false 的 label。 除了 instruction 以及 input 组成的人类最终输入和模型回答 output ,KTO 数据集还需要额外添加一个 kto_tag 列(true/false)来表示人类的反馈。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
目前我们支持 多模态图像数据集、 视频数据集 以及 音频数据集 的输入。
多模态图像数据集需要额外添加一个 images 列,包含输入图像的路径。 注意图片的数量必须与文本中所有 <image> 标记的数量严格一致。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
多模态视频数据集需要额外添加一个 videos 列,包含输入视频的路径。 注意视频的数量必须与文本中所有 <video> 标记的数量严格一致。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
多模态音频数据集需要额外添加一个 audios 列,包含输入音频的路径。 注意音频的数量必须与文本中所有 <audio> 标记的数量严格一致。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
ShareGPT 格式中的 KTO数据集(样例)和多模态数据集(样例) 与 Alpaca 格式的类似。
预训练数据集不支持 ShareGPT 格式。
相比 alpaca 格式的数据集, sharegpt 格式支持 更多 的角色种类,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 conversations 列中。 下面是 sharegpt 格式的一个例子:
注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
Sharegpt 格式的偏好数据集同样需要在 chosen 列中提供更优的消息,并在 rejected 列中提供更差的消息。 下面是一个例子:
对于上述格式的数据,dataset_info.json 中的 数据集描述 应为:
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
---
## Supervised Fine-tuning¶
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/sft.html
**Contents:**
- Supervised Fine-tuning¶
- 命令行¶
您可以使用以下命令使用 examples/train_lora/llama3_lora_sft.yaml 中的参数进行微调:
也可以通过追加参数更新 yaml 文件中的参数:
LLaMA-Factory 默认使用所有可见的计算设备。根据需求可通过 CUDA_VISIBLE_DEVICES 或 ASCEND_RT_VISIBLE_DEVICES 指定计算设备。
examples/train_lora/llama3_lora_sft.yaml 提供了微调时的配置示例。该配置指定了模型参数、微调方法参数、数据集参数以及评估参数等。您需要根据自身需求自行配置。
模型 model_name_or_path 、数据集 dataset 需要存在且与 template 相对应。
训练阶段,可选: rm(reward modeling), pt(pretrain), sft(Supervised Fine-Tuning), PPO, DPO, KTO, ORPO
微调方式。可选: freeze, lora, full
采取LoRA方法的目标模块,默认值为 all。
数据集模板,请保证数据集模板与模型相对应。
per_device_train_batch_size
gradient_accumulation_steps
学习率曲线,可选 linear, cosine, polynomial, constant 等。
---
@@ -0,0 +1,19 @@
# Llama-Factory Documentation Index
## Categories
### Images
**File:** `_images.md`
**Pages:** 3
### Advanced
**File:** `advanced.md`
**Pages:** 14
### Getting Started
**File:** `getting_started.md`
**Pages:** 7
### Other
**File:** `other.md`
**Pages:** 1
@@ -0,0 +1,31 @@
# Llama-Factory - Other
**Pages:** 1
---
## Welcome to LLaMA Factory!¶
**URL:** https://llamafactory.readthedocs.io/en/latest/
**Contents:**
- Welcome to LLaMA Factory!¶
- Documentation¶
LLaMA Factory is an easy-to-use and efficient platform for training and fine-tuning large language models. With LLaMA Factory, you can fine-tune hundreds of pre-trained models locally without writing any code. Framework features include:
Models: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
Trainers: (incremental) pre-training, (multimodal) instruction supervision fine-tuning, reward model training, PPO training, DPO training, KTO training, ORPO training, etc.
Computation Precision: 16-bit full-parameter fine-tuning, frozen fine-tuning, LoRA fine-tuning, and 2/3/4/5/6/8-bit QLoRA fine-tuning based on AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
Optimization Algorithms: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, and PiSSA.
Acceleration Operators: FlashAttention-2 and Unsloth.
Inference Engines: Transformers and vLLM.
Experiment Monitors: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab etc.
---
@@ -0,0 +1,431 @@
---
name: peft-fine-tuning
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with transformers ecosystem.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
---
# PEFT (Parameter-Efficient Fine-Tuning)
Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.
## When to use PEFT
**Use PEFT/LoRA when:**
- Fine-tuning 7B-70B models on limited hardware (consumer GPUs such as the RTX 4090, or a single A100)
- Need to train <1% parameters (6MB adapters vs 14GB full model)
- Want fast iteration with multiple task-specific adapters
- Deploying multiple fine-tuned variants from one base model
**Use QLoRA (PEFT + quantization) when:**
- Fine-tuning large models on a single GPU (e.g., 7B-13B on 24GB, 70B on 48GB or more)
- Memory is the primary constraint
- Can accept ~5% quality trade-off vs full fine-tuning
**Use full fine-tuning instead when:**
- Training small models (<1B parameters)
- Need maximum quality and have compute budget
- Significant domain shift requires updating all weights
## Quick start
### Installation
```bash
# Basic installation
pip install peft
# With quantization support (recommended)
pip install peft bitsandbytes
# Full stack
pip install peft transformers accelerate bitsandbytes datasets
```
### LoRA fine-tuning (standard)
```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

# Load base model
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # The tokenizer ships without a pad token

# LoRA configuration
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,  # Rank (8-64, higher = more capacity)
    lora_alpha=32,  # Scaling factor (typically 2*r)
    lora_dropout=0.05,  # Dropout for regularization
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Attention layers
    bias="none",  # Don't train biases
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%

# Prepare dataset
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")


def tokenize(example):
    """Render one Dolly record as an instruction/response prompt and tokenize it.

    Padding is deliberately left to the data collator so each batch is padded
    only to its own longest sequence.
    """
    text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
    return tokenizer(text, truncation=True, max_length=512)


tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

# The collator pads each batch, converts it to tensors, and copies input_ids to
# labels with pad positions set to -100 so padding does not contribute to the loss.
# (`datasets` stores token ids as Python lists, so torch.stack cannot be used directly.)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training
training_args = TrainingArguments(
    output_dir="./lora-llama",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,  # Matches the bf16 checkpoint; use fp16=True on GPUs without bf16 support
    logging_steps=10,
    save_strategy="epoch",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator,
)
trainer.train()

# Save adapter only (tens of MB for ~13.6M LoRA parameters vs ~16GB for the full model)
model.save_pretrained("./lora-llama-adapter")
```
### QLoRA fine-tuning (memory-efficient)
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4 (best for LLMs)
bnb_4bit_compute_dtype="bfloat16", # Compute in bf16
bnb_4bit_use_double_quant=True # Nested quantization
)
# Load quantized model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-70B",
quantization_config=bnb_config,
device_map="auto"
)
# Prepare for training (enables gradient checkpointing)
model = prepare_model_for_kbit_training(model)
# LoRA config for QLoRA
lora_config = LoraConfig(
r=64, # Higher rank for 70B
lora_alpha=128,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# 4-bit weights for a 70B model still need roughly 35-40GB: use a 48GB+ GPU, or let device_map="auto" shard across several 24GB GPUs
```
## LoRA parameter selection
### Rank (r) - capacity vs efficiency
| Rank | Trainable Params | Memory | Quality | Use Case |
|------|-----------------|--------|---------|----------|
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
| **8** | ~7M | Low | Good | **Recommended starting point** |
| **16** | ~14M | Medium | Better | **General fine-tuning** |
| 32 | ~27M | Higher | High | Complex tasks |
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |
### Alpha (lora_alpha) - scaling factor
```python
# Rule of thumb: alpha = 2 * rank
LoraConfig(r=16, lora_alpha=32) # Standard
LoraConfig(r=16, lora_alpha=16) # Conservative (lower learning rate effect)
LoraConfig(r=16, lora_alpha=64) # Aggressive (higher learning rate effect)
```
### Target modules by architecture
```python
# Llama / Mistral / Qwen
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# GPT-2 / GPT-Neo
target_modules = ["c_attn", "c_proj", "c_fc"]
# Falcon
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
# BLOOM
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
# Auto-detect all linear layers
target_modules = "all-linear" # PEFT 0.8.0+
```
## Loading and merging adapters
### Load trained adapter
```python
from peft import PeftModel, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM
# Option 1: Load with PeftModel
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")
# Option 2: Load directly (recommended)
model = AutoPeftModelForCausalLM.from_pretrained(
"./lora-llama-adapter",
device_map="auto"
)
```
### Merge adapter into base model
```python
# Merge for deployment (no adapter overhead)
merged_model = model.merge_and_unload()
# Save merged model
merged_model.save_pretrained("./llama-merged")
tokenizer.save_pretrained("./llama-merged")
# Push to Hub
merged_model.push_to_hub("username/llama-finetuned")
```
### Multi-adapter serving
```python
from peft import AutoPeftModelForCausalLM

# Load base with first adapter. Register it under an explicit name; without
# adapter_name it would be called "default" and set_adapter("task1") would fail.
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1", adapter_name="task1")

# Load additional adapters
model.load_adapter("./adapter-task2", adapter_name="task2")
model.load_adapter("./adapter-task3", adapter_name="task3")

# Switch between adapters at runtime
model.set_adapter("task1")  # Use task1 adapter
output1 = model.generate(**inputs)
model.set_adapter("task2")  # Switch to task2
output2 = model.generate(**inputs)

# Disable adapters (use base model)
with model.disable_adapter():
    base_output = model.generate(**inputs)
```
## PEFT methods comparison
| Method | Trainable % | Memory | Speed | Best For |
|--------|------------|--------|-------|----------|
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |
### IA3 (minimal parameters)
```python
from peft import IA3Config
ia3_config = IA3Config(
target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
feedforward_modules=["down_proj"]
)
model = get_peft_model(model, ia3_config)
# Trains only 0.01% of parameters!
```
### Prefix Tuning
```python
from peft import PrefixTuningConfig
prefix_config = PrefixTuningConfig(
task_type="CAUSAL_LM",
num_virtual_tokens=20, # Prepended tokens
prefix_projection=True # Use MLP projection
)
model = get_peft_model(model, prefix_config)
```
## Integration patterns
### With TRL (SFTTrainer)
```python
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
trainer = SFTTrainer(
model=model,
args=SFTConfig(output_dir="./output", max_seq_length=512),
train_dataset=dataset,
peft_config=lora_config, # Pass LoRA config directly
)
trainer.train()
```
### With Axolotl (YAML config)
```yaml
# axolotl config.yaml
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
lora_target_linear: true # Target all linear layers
```
### With vLLM (inference)
```python
from vllm import LLM
from vllm.lora.request import LoRARequest
# Load base model with LoRA support
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)
# Serve with adapter
outputs = llm.generate(
prompts,
lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
)
```
## Performance benchmarks
### Memory usage (Llama 3.1 8B)
| Method | GPU Memory | Trainable Params |
|--------|-----------|------------------|
| Full fine-tuning | 60+ GB | 8B (100%) |
| LoRA r=16 | 18 GB | 14M (0.17%) |
| QLoRA r=16 | 6 GB | 14M (0.17%) |
| IA3 | 16 GB | 800K (0.01%) |
### Training speed (A100 80GB)
| Method | Tokens/sec | vs Full FT |
|--------|-----------|------------|
| Full FT | 2,500 | 1x |
| LoRA | 3,200 | 1.3x |
| QLoRA | 2,100 | 0.84x |
### Quality (MMLU benchmark)
| Model | Full FT | LoRA | QLoRA |
|-------|---------|------|-------|
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
| Llama 2-13B | 54.8 | 54.2 | 53.5 |
## Common issues
### CUDA OOM during training
```python
# Solution 1: Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Solution 2: Reduce batch size + increase accumulation
TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=16
)
# Solution 3: Use QLoRA
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
```
### Adapter not applying
```python
# Verify adapter is active
print(model.active_adapters) # Should show adapter name
# Check trainable parameters
model.print_trainable_parameters()
# Ensure model in training mode
model.train()
```
### Quality degradation
```python
# Increase rank
LoraConfig(r=32, lora_alpha=64)
# Target more modules
target_modules = "all-linear"
# Use more training data and epochs
TrainingArguments(num_train_epochs=5)
# Lower learning rate
TrainingArguments(learning_rate=1e-4)
```
## Best practices
1. **Start with r=8-16**, increase if quality insufficient
2. **Use alpha = 2 * rank** as starting point
3. **Target attention + MLP layers** for best quality/efficiency
4. **Enable gradient checkpointing** for memory savings
5. **Save adapters frequently** (small files, easy rollback)
6. **Evaluate on held-out data** before merging
7. **Use QLoRA for 70B+ models** on consumer hardware
## References
- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization
## Resources
- **GitHub**: https://github.com/huggingface/peft
- **Docs**: https://huggingface.co/docs/peft
- **LoRA Paper**: arXiv:2106.09685
- **QLoRA Paper**: arXiv:2305.14314
- **Models**: https://huggingface.co/models?library=peft
@@ -0,0 +1,514 @@
# PEFT Advanced Usage Guide
## Advanced LoRA Variants
### DoRA (Weight-Decomposed Low-Rank Adaptation)
DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:
```python
from peft import LoraConfig
dora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
use_dora=True, # Enable DoRA
task_type="CAUSAL_LM"
)
model = get_peft_model(model, dora_config)
```
**When to use DoRA**:
- Consistently outperforms LoRA on instruction-following tasks
- Slightly higher memory (~10%) due to magnitude vectors
- Best for quality-critical fine-tuning
### AdaLoRA (Adaptive Rank)
Automatically adjusts rank per layer based on importance:
```python
from peft import AdaLoraConfig
adalora_config = AdaLoraConfig(
init_r=64, # Initial rank
target_r=16, # Target average rank
tinit=200, # Warmup steps
tfinal=1000, # Final pruning step
deltaT=10, # Rank update frequency
beta1=0.85,
beta2=0.85,
orth_reg_weight=0.5, # Orthogonality regularization
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM"
)
```
**Benefits**:
- Allocates more rank to important layers
- Can reduce total parameters while maintaining quality
- Good for exploring optimal rank distribution
### LoRA+ (Asymmetric Learning Rates)
Different learning rates for A and B matrices:
```python
from peft import LoraConfig
# LoRA+ uses higher LR for B matrix
lora_plus_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules="all-linear",
use_rslora=True, # Rank-stabilized LoRA (related technique)
)
# Manual implementation of LoRA+
from torch.optim import AdamW
# Group parameters
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]
optimizer = AdamW([
{"params": lora_A_params, "lr": 1e-4},
{"params": lora_B_params, "lr": 1e-3}, # 10x higher for B
])
```
### rsLoRA (Rank-Stabilized LoRA)
Scales LoRA outputs to stabilize training with different ranks:
```python
lora_config = LoraConfig(
r=64,
lora_alpha=64,
use_rslora=True, # Enables rank-stabilized scaling
target_modules="all-linear"
)
```
**When to use**:
- When experimenting with different ranks
- Helps maintain consistent behavior across rank values
- Recommended for r > 32
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
Initializes LoRA weights to compensate for quantization error:
```python
from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# LoftQ configuration
loftq_config = LoftQConfig(
    loftq_bits=4,  # Quantization bits
    loftq_iter=5,  # Alternating optimization iterations
)

# LoRA config with LoftQ initialization
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    init_lora_weights="loftq",
    loftq_config=loftq_config,
    task_type="CAUSAL_LM",
)

# Load the base model in full precision -- do NOT pass a quantization_config here.
# LoftQ needs the original weights to compute the quantization error that the
# LoRA matrices are initialized to compensate for; it quantizes them itself.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = get_peft_model(model, lora_config)
```
**Benefits over standard QLoRA**:
- Better initial quality after quantization
- Faster convergence
- ~1-2% better final accuracy on benchmarks
## Custom Module Targeting
### Target specific layers
```python
# Target only first and last transformer layers
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.v_proj",
"model.layers.31.self_attn.q_proj",
"model.layers.31.self_attn.v_proj"],
layers_to_transform=[0, 31] # Alternative approach
)
```
### Layer pattern matching
```python
# Target layers 0-10 only
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules="all-linear",
layers_to_transform=list(range(11)), # Layers 0-10
layers_pattern="model.layers"
)
```
### Exclude specific layers
```python
lora_config = LoraConfig(
r=16,
target_modules="all-linear",
modules_to_save=["lm_head"], # Train these fully (not LoRA)
)
```
## Embedding and LM Head Training
### Train embeddings with LoRA
```python
from peft import LoraConfig
# Include embeddings
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "embed_tokens"], # Include embeddings
modules_to_save=["lm_head"], # Train lm_head fully
)
```
### Extending vocabulary with LoRA
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig
# Add new tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
tokenizer.add_tokens(new_tokens)
# Resize model embeddings
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model.resize_token_embeddings(len(tokenizer))
# Configure LoRA to train new embeddings
lora_config = LoraConfig(
r=16,
target_modules="all-linear",
modules_to_save=["embed_tokens", "lm_head"], # Train these fully
)
model = get_peft_model(model, lora_config)
```
## Multi-Adapter Patterns
### Adapter composition
```python
from peft import AutoPeftModelForCausalLM

# Load model with multiple adapters (the first one is registered as "default")
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
model.load_adapter("./style-adapter", adapter_name="style")
model.load_adapter("./task-adapter", adapter_name="task")

# Combine adapters (weighted sum)
model.add_weighted_adapter(
    adapters=["style", "task"],
    weights=[0.7, 0.3],
    adapter_name="combined",
    combination_type="linear",  # or "cat", "svd"; "linear" requires equal ranks
)
model.set_adapter("combined")
```
### Adapter stacking
```python
# Combine adapters by concatenating their LoRA matrices along the rank dimension.
# The resulting adapter applies the weighted sum of the individual weight updates
# in a single pass; the adapters are NOT applied one after another.
model.add_weighted_adapter(
adapters=["base", "domain", "task"],
weights=[1.0, 1.0, 1.0],
adapter_name="stacked",
combination_type="cat" # Concatenate the LoRA A/B matrices (the ranks add up)
)
```
### Dynamic adapter switching
```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


class MultiAdapterModel:
    """Serve several LoRA adapters on top of one shared base model.

    Args:
        base_model_path: Path or Hub ID of the base model; used to load the tokenizer.
        adapter_paths: Mapping of adapter name -> adapter directory. The first
            entry is loaded together with the base model, the rest are attached
            to it. Key one entry "default" if you want to call ``generate``
            without passing ``adapter_name``.

    Raises:
        ValueError: If ``adapter_paths`` is empty.
    """

    def __init__(self, base_model_path, adapter_paths):
        if not adapter_paths:
            raise ValueError("adapter_paths must contain at least one adapter")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        named_paths = list(adapter_paths.items())
        first_name, first_path = named_paths[0]
        # The base model is resolved from the adapter's own config; registering
        # the adapter under its mapping key keeps set_adapter() names consistent.
        self.model = AutoPeftModelForCausalLM.from_pretrained(first_path, adapter_name=first_name)
        for name, path in named_paths[1:]:
            self.model.load_adapter(path, adapter_name=name)

    def tokenize(self, prompt):
        """Tokenize ``prompt`` into tensors placed on the model's device."""
        return self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

    def generate(self, prompt, adapter_name="default"):
        """Generate token ids for ``prompt`` with the adapter named ``adapter_name`` active."""
        self.model.set_adapter(adapter_name)
        return self.model.generate(**self.tokenize(prompt))

    def generate_ensemble(self, prompt, adapters, weights):
        """Return the weighted sum of the logits produced by each adapter in ``adapters``."""
        inputs = self.tokenize(prompt)  # Tokenize once; the input is the same for every adapter
        outputs = []
        for adapter, weight in zip(adapters, weights):
            self.model.set_adapter(adapter)
            logits = self.model(**inputs).logits
            outputs.append(weight * logits)
        return torch.stack(outputs).sum(dim=0)
```
## Memory Optimization
### Gradient checkpointing with LoRA
```python
from peft import prepare_model_for_kbit_training
# Enable gradient checkpointing
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False}
)
```
### CPU offloading for training
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Accelerator itself has no `cpu_offload` argument. Optimizer-state offloading is
# configured through a DeepSpeed (or FSDP) plugin. Requires `pip install deepspeed`
# and starting the script with `accelerate launch`.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,
    gradient_accumulation_steps=8,
    offload_optimizer_device="cpu",  # Offload optimizer states to CPU
)
accelerator = Accelerator(
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    deepspeed_plugin=deepspeed_plugin,
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
### Memory-efficient attention with LoRA
```python
import torch
from peft import get_peft_model
from transformers import AutoModelForCausalLM

# Combine Flash Attention 2 with LoRA (Flash Attention 2 requires fp16/bf16 weights)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Apply LoRA (lora_config: any LoraConfig defined as in the sections above)
model = get_peft_model(model, lora_config)
```
## Inference Optimization
### Merge for deployment
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Merge adapter weights into base model
merged_model = model.merge_and_unload()

# Write the merged weights to disk so they can be reloaded below
merged_model.save_pretrained("./merged-model")

# Quantize merged model for inference
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "./merged-model",
    quantization_config=bnb_config,
)
```
### Export to different formats
```python
# Export to GGUF (llama.cpp)
# First merge, then convert
merged_model.save_pretrained("./merged-model")
# Use llama.cpp converter
# python convert_hf_to_gguf.py ./merged-model --outfile model.gguf  (named convert-hf-to-gguf.py in older llama.cpp checkouts)
# Export to ONNX
from optimum.onnxruntime import ORTModelForCausalLM
ort_model = ORTModelForCausalLM.from_pretrained(
"./merged-model",
export=True
)
ort_model.save_pretrained("./onnx-model")
```
### Batch adapter inference
```python
from vllm import LLM
from vllm.lora.request import LoRARequest
# Initialize with LoRA support
llm = LLM(
model="meta-llama/Llama-3.1-8B",
enable_lora=True,
max_lora_rank=64,
max_loras=4 # Max concurrent adapters
)
# Batch with different adapters
requests = [
("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
]
outputs = llm.generate(
[r[0] for r in requests],
lora_request=[r[1] for r in requests]
)
```
## Training Recipes
### Instruction tuning recipe
```python
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules="all-linear",
bias="none",
task_type="CAUSAL_LM"
)
training_args = TrainingArguments(
output_dir="./output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.03,
bf16=True,
logging_steps=10,
save_strategy="steps",
save_steps=100,
eval_strategy="steps",
eval_steps=100,
)
```
### Code generation recipe
```python
from peft import LoraConfig
from trl import SFTConfig

lora_config = LoraConfig(
    r=32,  # Higher rank for code
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# Plain TrainingArguments rejects `max_seq_length`; the sequence-length limit is
# an SFTConfig option (named `max_length` in recent TRL releases).
training_args = SFTConfig(
    output_dir="./output",
    learning_rate=1e-4,  # Lower LR for code
    num_train_epochs=2,
    max_seq_length=2048,  # Longer sequences
)
```
### Conversational/Chat recipe
```python
from trl import SFTTrainer
lora_config = LoraConfig(
r=16,
lora_alpha=16, # alpha = r for chat
lora_dropout=0.05,
target_modules="all-linear"
)
# Use chat template
def format_chat(example):
    """Render one instruction/response pair with the tokenizer's chat template.

    Args:
        example: Dataset row with "instruction" and "response" keys.

    Returns:
        A dict with a single "text" key holding the templated conversation.
    """
    messages = [
        {"role": "user", "content": example["instruction"]},
        {"role": "assistant", "content": example["response"]}
    ]
    # Dataset.map requires the mapped function to return a dict (a bare string
    # raises TypeError); "text" is the column SFTTrainer reads by default.
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
trainer = SFTTrainer(
model=model,
peft_config=lora_config,
train_dataset=dataset.map(format_chat),
max_seq_length=1024,
)
```
## Debugging and Validation
### Verify adapter application
```python
# Check which modules have LoRA
for name, module in model.named_modules():
if hasattr(module, "lora_A"):
print(f"LoRA applied to: {name}")
# Print detailed config
print(model.peft_config)
# Check adapter state
print(f"Active adapters: {model.active_adapters}")
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
```
### Compare with base model
```python
# Generate with adapter
model.set_adapter("default")
adapter_output = model.generate(**inputs)
# Generate without adapter
with model.disable_adapter():
base_output = model.generate(**inputs)
print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
print(f"Base: {tokenizer.decode(base_output[0])}")
```
### Monitor training metrics
```python
from transformers import TrainerCallback
class LoRACallback(TrainerCallback):
    """Print the loss and the number of trainable LoRA parameters on each log event."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Report loss and LoRA parameter count when the log entry contains a loss.

        Args:
            args, state, control: Standard callback arguments; only
                ``state.global_step`` is read here.
            logs: Mapping of logged values; may be None.
            **kwargs: Extra callback arguments; ``model`` is read if present.
        """
        # `logs` defaults to None, so guard before the membership test.
        if not logs or "loss" not in logs:
            return
        model = kwargs.get("model")
        if model is None:
            return
        # Log adapter-specific metrics: trainable parameters whose name contains "lora".
        lora_params = sum(
            p.numel()
            for n, p in model.named_parameters()
            if "lora" in n and p.requires_grad
        )
        print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
```
@@ -0,0 +1,480 @@
# PEFT Troubleshooting Guide
## Installation Issues
### bitsandbytes CUDA Error
**Error**: `CUDA Setup failed despite GPU being available`
**Fix**:
```bash
# Check CUDA version
nvcc --version
# Install matching bitsandbytes
pip uninstall bitsandbytes
pip install bitsandbytes --no-cache-dir
# Or compile from source for specific CUDA
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=118 make cuda11x # Adjust for your CUDA
pip install .
```
### Triton Import Error
**Error**: `ModuleNotFoundError: No module named 'triton'`
**Fix**:
```bash
# Install triton (Linux only)
pip install triton
# Windows: Triton not supported, use CUDA backend
# NOTE: CUDA_VISIBLE_DEVICES only restricts which GPUs the process can see
# (here: GPU 0); it does not by itself disable Triton.
export CUDA_VISIBLE_DEVICES=0
```
### PEFT Version Conflicts
**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`
**Fix**:
```bash
# Upgrade to latest PEFT.
# Quote the requirement: an unquoted ">" is a shell redirection, which would
# write pip's output to a file named "=0.13.0" and drop the version pin.
pip install --upgrade "peft>=0.13.0"
# Check version
python -c "import peft; print(peft.__version__)"
```
## Training Issues
### CUDA Out of Memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
1. **Enable gradient checkpointing**:
```python
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
```
2. **Reduce batch size**:
```python
TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=16 # Maintain effective batch size
)
```
3. **Use QLoRA**:
```python
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
```
4. **Lower LoRA rank**:
```python
LoraConfig(r=8) # Instead of r=16 or higher
```
5. **Target fewer modules**:
```python
target_modules=["q_proj", "v_proj"] # Instead of all-linear
```
### Loss Not Decreasing
**Problem**: Training loss stays flat or increases.
**Solutions**:
1. **Check learning rate**:
```python
# Start lower
TrainingArguments(learning_rate=1e-4) # Not 2e-4 or higher
```
2. **Verify adapter is active**:
```python
model.print_trainable_parameters()
# Should show >0 trainable params
# Check adapter applied
print(model.peft_config)
```
3. **Check data formatting**:
```python
# Verify tokenization
sample = dataset[0]
decoded = tokenizer.decode(sample["input_ids"])
print(decoded) # Should look correct
```
4. **Increase rank**:
```python
LoraConfig(r=32, lora_alpha=64) # More capacity
```
### NaN Loss
**Error**: `Loss is NaN`
**Fix**:
```python
# Use bf16 instead of fp16 (wider exponent range, far less prone to overflow)
TrainingArguments(bf16=True, fp16=False)
# Or, if bf16 is unavailable: fp16=True already uses dynamic loss scaling;
# keep gradient clipping on to limit exploding updates
TrainingArguments(fp16=True, max_grad_norm=1.0)
# Lower learning rate
TrainingArguments(learning_rate=5e-5)
# Check for data issues. Token ids are integers and can never be NaN, so look
# instead for batches whose labels are all masked (-100): averaging the loss
# over zero target tokens yields NaN.
# Assumes each batch carries a "labels" tensor - adjust the key to your collator.
for batch in dataloader:
    if (batch["labels"] == -100).all():
        print("Batch has no unmasked labels!")
```
### Adapter Not Training
**Problem**: `trainable params: 0` or model not updating.
**Fix**:
```python
# Verify LoRA applied to correct modules
for name, module in model.named_modules():
if "lora" in name.lower():
print(f"Found LoRA: {name}")
# Check target_modules match model architecture
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))
# Ensure model in training mode
model.train()
# Check requires_grad
for name, param in model.named_parameters():
if param.requires_grad:
print(f"Trainable: {name}")
```
## Loading Issues
### Adapter Loading Fails
**Error**: `ValueError: Can't find adapter weights`
**Fix**:
```python
# Check adapter files exist
import os
print(os.listdir("./adapter-path"))
# Should contain: adapter_config.json, adapter_model.safetensors
# Load with correct structure
from peft import PeftModel, PeftConfig
# Check config
config = PeftConfig.from_pretrained("./adapter-path")
print(config)
# Load base model first
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, "./adapter-path")
```
### Base Model Mismatch
**Error**: `RuntimeError: size mismatch`
**Fix**:
```python
# Ensure base model matches adapter
from peft import PeftConfig
config = PeftConfig.from_pretrained("./adapter-path")
print(f"Base model: {config.base_model_name_or_path}")
# Load exact same base model
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
```
### Safetensors vs PyTorch Format
**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`
**Fix**:
```python
# Force local loading
model = PeftModel.from_pretrained(
base_model,
"./adapter-path",
local_files_only=True
)
# Or specify format
model.save_pretrained("./adapter", safe_serialization=True) # safetensors
model.save_pretrained("./adapter", safe_serialization=False) # pytorch
```
## Inference Issues
### Slow Generation
**Problem**: Inference much slower than expected.
**Solutions**:
1. **Merge adapter for deployment**:
```python
merged_model = model.merge_and_unload()
# No adapter overhead during inference
```
2. **Use optimized inference engine**:
```python
from vllm import LLM
llm = LLM(model="./merged-model", dtype="half")
```
3. **Enable Flash Attention**:
```python
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2"
)
```
### Output Quality Issues
**Problem**: Fine-tuned model produces worse outputs.
**Solutions**:
1. **Check evaluation without adapter**:
```python
with model.disable_adapter():
base_output = model.generate(**inputs)
# Compare with adapter output
```
2. **Lower temperature during eval**:
```python
# Low-temperature sampling (temperature only takes effect when do_sample=True)
model.generate(**inputs, temperature=0.1, do_sample=True)
# Or fully deterministic greedy decoding (temperature is ignored here)
model.generate(**inputs, do_sample=False)
```
3. **Retrain with more data**:
```python
# Increase training samples
# Use higher quality data
# Train for more epochs
```
### Wrong Adapter Active
**Problem**: Model using wrong adapter or no adapter.
**Fix**:
```python
# Check active adapters
print(model.active_adapters)
# Explicitly set adapter
model.set_adapter("your-adapter-name")
# List all adapters
print(model.peft_config.keys())
```
## QLoRA Specific Issues
### Quantization Errors
**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`
**Fix**:
```python
# Ensure compute dtype matches
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16, # Match model dtype
bnb_4bit_quant_type="nf4"
)
# Load with correct dtype
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16
)
```
### QLoRA OOM
**Error**: OOM even with 4-bit quantization.
**Fix**:
```python
# Enable double quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True # Further memory reduction
)
# Use offloading
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
max_memory={0: "20GB", "cpu": "100GB"}
)
```
### QLoRA Merge Fails
**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`
**Fix**:
```python
# Dequantize before merging
from peft import PeftModel
# Load in higher precision for merging
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16, # Not quantized
device_map="auto"
)
# Load adapter
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
# Now merge
merged = model.merge_and_unload()
```
## Multi-Adapter Issues
### Adapter Conflict
**Error**: `ValueError: Adapter with name 'default' already exists`
**Fix**:
```python
# Use unique names
model.load_adapter("./adapter1", adapter_name="task1")
model.load_adapter("./adapter2", adapter_name="task2")
# Or delete existing
model.delete_adapter("default")
```
### Mixed Precision Adapters
**Error**: Adapters trained with different dtypes.
**Fix**:
```python
# Convert adapter precision
model = PeftModel.from_pretrained(base_model, "./adapter")
model = model.to(torch.bfloat16)
# Or load with specific dtype
model = PeftModel.from_pretrained(
base_model,
"./adapter",
torch_dtype=torch.bfloat16
)
```
## Performance Optimization
### Memory Profiling
```python
import torch
def print_memory():
    """Print allocated and reserved CUDA memory in GB (1e9 bytes); do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
# Profile during training
print_memory() # Before
model.train()
loss = model(**batch).loss
loss.backward()
print_memory() # After
```
### Speed Profiling
```python
import time
import torch
def benchmark_generation(model, tokenizer, prompt, n_runs=5):
    """Measure average generation throughput for a prompt.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs, max_new_tokens=...)``.
        tokenizer: Callable returning inputs with ``.to(device)`` and ``.input_ids``.
        prompt: Text to generate from.
        n_runs: Number of timed generations to average over (must be >= 1).

    Returns:
        Tokens generated per second, averaged over ``n_runs``.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be >= 1")
    # Only synchronize when CUDA exists; the call fails on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Warmup
    model.generate(**inputs, max_new_tokens=10)
    if use_cuda:
        torch.cuda.synchronize()
    # Benchmark
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        outputs = model.generate(**inputs, max_new_tokens=100)
        if use_cuda:
            # Wait for queued GPU work so the timer covers the whole generation.
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    # New tokens = output length minus prompt length (taken from the last run).
    tokens = outputs.shape[1] - inputs.input_ids.shape[1]
    avg_time = sum(times) / len(times)
    tokens_per_sec = tokens / avg_time
    print(f"Speed: {tokens_per_sec:.2f} tokens/sec")
    return tokens_per_sec
# Compare adapter vs merged
benchmark_generation(adapter_model, tokenizer, "Hello")
benchmark_generation(merged_model, tokenizer, "Hello")
```
## Getting Help
1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co/
3. **PEFT Documentation**: https://huggingface.co/docs/peft
### Debugging Template
When reporting issues, include:
```python
# System info
import peft
import transformers
import torch
print(f"PEFT: {peft.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
# Config
print(model.peft_config)
model.print_trainable_parameters()
```
@@ -0,0 +1,80 @@
---
name: unsloth
description: Expert guidance for fast fine-tuning with Unsloth - 2-5x faster training, 50-80% less memory, LoRA/QLoRA optimization
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Fine-Tuning, Unsloth, Fast Training, LoRA, QLoRA, Memory-Efficient, Optimization, Llama, Mistral, Gemma, Qwen]
dependencies: [unsloth, torch, transformers, trl, datasets, peft]
---
# Unsloth Skill
Comprehensive assistance with unsloth development, generated from official documentation.
## When to Use This Skill
This skill should be triggered when:
- Working with unsloth
- Asking about unsloth features or APIs
- Implementing unsloth solutions
- Debugging unsloth code
- Learning unsloth best practices
## Quick Reference
### Common Patterns
*Quick reference patterns will be added as you use the skill.*
## Reference Files
This skill includes comprehensive documentation in `references/`:
- **llms-txt.md** - Llms-Txt documentation
Use `view` to read specific reference files when detailed information is needed.
## Working with This Skill
### For Beginners
Start with the getting_started or tutorials reference files for foundational concepts.
### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.
### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.
## Resources
### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation
### scripts/
Add helper scripts here for common automation tasks.
### assets/
Add templates, boilerplate, or example projects here.
## Notes
- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs
## Updating
To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information
<!-- Trigger re-upload 1763621536 -->
@@ -0,0 +1,7 @@
# Unsloth Documentation Index
## Categories
### Llms-Txt
**File:** `llms-txt.md`
**Pages:** 136
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,82 @@
# Unsloth Documentation
## Unsloth Documentation
- [Unsloth Docs](/get-started/unsloth-docs.md): Train your own model with Unsloth, an open-source framework for LLM fine-tuning and reinforcement learning.
- [Beginner? Start here!](/get-started/beginner-start-here.md)
- [Unsloth Requirements](/get-started/beginner-start-here/unsloth-requirements.md): Here are Unsloth's requirements including system and GPU VRAM requirements.
- [FAQ + Is Fine-tuning Right For Me?](/get-started/beginner-start-here/faq-+-is-fine-tuning-right-for-me.md): If you're unsure whether fine-tuning is right for you, see here! Learn about fine-tuning misconceptions, how it compares to RAG and more:
- [Unsloth Notebooks](/get-started/unsloth-notebooks.md): Explore our catalog of Unsloth notebooks:
- [All Our Models](/get-started/all-our-models.md)
- [Install & Update](/get-started/install-and-update.md): Learn to install Unsloth locally or online.
- [Updating](/get-started/install-and-update/updating.md): To update or use an old version of Unsloth, follow the steps below:
- [Pip Install](/get-started/install-and-update/pip-install.md): To install Unsloth locally via Pip, follow the steps below:
- [Docker](/get-started/install-and-update/docker.md): Install Unsloth using our official Docker container
- [Windows Installation](/get-started/install-and-update/windows-installation.md): See how to install Unsloth on Windows with or without WSL.
- [AMD](/get-started/install-and-update/amd.md): Fine-tune with Unsloth on AMD GPUs.
- [Conda Install](/get-started/install-and-update/conda-install.md): To install Unsloth locally on Conda, follow the steps below:
- [Google Colab](/get-started/install-and-update/google-colab.md): To install and run Unsloth on Google Colab, follow the steps below:
- [Fine-tuning LLMs Guide](/get-started/fine-tuning-llms-guide.md): Learn all the basics and best practices of fine-tuning. Beginner-friendly.
- [What Model Should I Use?](/get-started/fine-tuning-llms-guide/what-model-should-i-use.md)
- [Datasets Guide](/get-started/fine-tuning-llms-guide/datasets-guide.md): Learn how to create & prepare a dataset for fine-tuning.
- [LoRA Hyperparameters Guide](/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide.md): Optimal LoRA rank, alpha, number of epochs, batch size & gradient accumulation, QLoRA vs LoRA, target modules and more!
- [Tutorial: How to Finetune Llama-3 and Use In Ollama](/get-started/fine-tuning-llms-guide/tutorial-how-to-finetune-llama-3-and-use-in-ollama.md): Beginner's Guide for creating a customized personal assistant (like ChatGPT) to run locally on Ollama
- [Reinforcement Learning (RL) Guide](/get-started/reinforcement-learning-rl-guide.md): Learn all about Reinforcement Learning (RL) and how to train your own DeepSeek-R1 reasoning model with Unsloth using GRPO. A complete guide from beginner to advanced.
- [Tutorial: Train your own Reasoning model with GRPO](/get-started/reinforcement-learning-rl-guide/tutorial-train-your-own-reasoning-model-with-grpo.md): Beginner's Guide to transforming a model like Llama 3.1 (8B) into a reasoning model by using Unsloth and GRPO.
- [Advanced RL Documentation](/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation.md): Advanced documentation settings when using Unsloth with GRPO.
- [Memory Efficient RL](/get-started/reinforcement-learning-rl-guide/memory-efficient-rl.md)
- [RL Reward Hacking](/get-started/reinforcement-learning-rl-guide/rl-reward-hacking.md): Learn what is Reward Hacking in Reinforcement Learning and how to counter it.
- [GSPO Reinforcement Learning](/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning.md): Train with GSPO (Group Sequence Policy Optimization) RL in Unsloth.
- [Reinforcement Learning - DPO, ORPO & KTO](/get-started/reinforcement-learning-rl-guide/reinforcement-learning-dpo-orpo-and-kto.md): To use the reward modelling functions for DPO, GRPO, ORPO or KTO with Unsloth, follow the steps below:
- [DeepSeek-OCR: How to Run & Fine-tune](/new/deepseek-ocr-how-to-run-and-fine-tune.md): Guide on how to run and fine-tune DeepSeek-OCR locally.
- [How to Fine-tune LLMs with Unsloth & Docker](/new/how-to-fine-tune-llms-with-unsloth-and-docker.md): Learn how to fine-tune LLMs or do Reinforcement Learning (RL) with Unsloth's Docker image.
- [Vision Reinforcement Learning (VLM RL)](/new/vision-reinforcement-learning-vlm-rl.md): Train Vision/multimodal models via GRPO and RL with Unsloth!
- [gpt-oss Reinforcement Learning](/new/gpt-oss-reinforcement-learning.md)
- [Tutorial: How to Train gpt-oss with RL](/new/gpt-oss-reinforcement-learning/tutorial-how-to-train-gpt-oss-with-rl.md): Learn to train OpenAI gpt-oss with GRPO to autonomously beat 2048 locally or on Colab.
- [Unsloth Dynamic GGUFs on Aider Polyglot](/new/unsloth-dynamic-ggufs-on-aider-polyglot.md): Performance of Unsloth Dynamic GGUFs on Aider Polyglot Benchmarks
- [Qwen3-VL: How to Run & Fine-tune](/models/qwen3-vl-how-to-run-and-fine-tune.md): Learn to fine-tune and run Qwen3-VL locally with Unsloth.
- [gpt-oss: How to Run & Fine-tune](/models/gpt-oss-how-to-run-and-fine-tune.md): Run & fine-tune OpenAI's new open-source models!
- [Tutorial: How to Fine-tune gpt-oss](/models/gpt-oss-how-to-run-and-fine-tune/tutorial-how-to-fine-tune-gpt-oss.md): Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
- [Long Context gpt-oss Training](/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md)
- [GLM-4.6: How to Run Locally](/models/glm-4.6-how-to-run-locally.md): A guide on how to run Z.ai's new GLM-4.6 model on your own local device!
- [IBM Granite 4.0](/models/ibm-granite-4.0.md): How to run IBM Granite-4.0 with Unsloth GGUFs on llama.cpp, Ollama and how to fine-tune!
- [DeepSeek-V3.1: How to Run Locally](/models/deepseek-v3.1-how-to-run-locally.md): A guide on how to run DeepSeek-V3.1 and Terminus on your own local device!
- [Qwen3-Coder: How to Run Locally](/models/qwen3-coder-how-to-run-locally.md): Run Qwen3-Coder-30B-A3B-Instruct and 480B-A35B locally with Unsloth Dynamic quants.
- [Gemma 3: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune.md): How to run Gemma 3 effectively with our GGUFs on llama.cpp, Ollama, Open WebUI and how to fine-tune with Unsloth!
- [Gemma 3n: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune.md): Run Google's new Gemma 3n locally with Dynamic GGUFs on llama.cpp, Ollama, Open WebUI and fine-tune with Unsloth!
- [Qwen3: How to Run & Fine-tune](/models/qwen3-how-to-run-and-fine-tune.md): Learn to run & fine-tune Qwen3 locally with Unsloth + our Dynamic 2.0 quants
- [Qwen3-2507](/models/qwen3-how-to-run-and-fine-tune/qwen3-2507.md): Run Qwen3-30B-A3B-2507 and 235B-A22B Thinking and Instruct versions locally on your device!
- [Tutorials: How To Fine-tune & Run LLMs](/models/tutorials-how-to-fine-tune-and-run-llms.md): Learn how to run and fine-tune models for optimal performance 100% locally with Unsloth.
- [DeepSeek-R1-0528: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-0528-how-to-run-locally.md): A guide on how to run DeepSeek-R1-0528 including Qwen3 on your own local device!
- [Magistral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/magistral-how-to-run-and-fine-tune.md): Meet Magistral - Mistral's new reasoning models.
- [Llama 4: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/llama-4-how-to-run-and-fine-tune.md): How to run Llama 4 locally using our dynamic GGUFs which recovers accuracy compared to standard quantization.
- [Kimi K2: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/kimi-k2-how-to-run-locally.md): Guide on running Kimi K2 and Kimi-K2-Instruct-0905 on your own local device!
- [Grok 2](/models/tutorials-how-to-fine-tune-and-run-llms/grok-2.md): Run xAI's Grok 2 model locally!
- [Devstral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/devstral-how-to-run-and-fine-tune.md): Run and fine-tune Mistral Devstral 1.1, including Small-2507 and 2505.
- [DeepSeek-V3-0324: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-v3-0324-how-to-run-locally.md): How to run DeepSeek-V3-0324 locally using our dynamic quants which recovers accuracy
- [DeepSeek-R1: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally.md): A guide on how you can run our 1.58-bit Dynamic Quants for DeepSeek-R1 using llama.cpp.
- [DeepSeek-R1 Dynamic 1.58-bit](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally/deepseek-r1-dynamic-1.58-bit.md): See performance comparison tables for Unsloth's Dynamic GGUF Quants vs Standard IMatrix Quants.
- [QwQ-32B: How to Run effectively](/models/tutorials-how-to-fine-tune-and-run-llms/qwq-32b-how-to-run-effectively.md): How to run QwQ-32B effectively with our bug fixes and without endless generations + GGUFs.
- [Phi-4 Reasoning: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/phi-4-reasoning-how-to-run-and-fine-tune.md): Learn to run & fine-tune Phi-4 reasoning models locally with Unsloth + our Dynamic 2.0 quants
- [Running & Saving Models](/basics/running-and-saving-models.md): Learn how to save your finetuned model so you can run it in your favorite inference engine.
- [Saving to GGUF](/basics/running-and-saving-models/saving-to-gguf.md): Saving models to 16bit for GGUF so you can use it for Ollama, Jan AI, Open WebUI and more!
- [Saving to Ollama](/basics/running-and-saving-models/saving-to-ollama.md)
- [Saving to vLLM for deployment](/basics/running-and-saving-models/saving-to-vllm-for-deployment.md): Saving models to 16bit for vLLM deployment and serving
- [Saving to SGLang for deployment](/basics/running-and-saving-models/saving-to-sglang-for-deployment.md): Saving models to 16bit for SGLang for deployment and serving
- [Unsloth Inference](/basics/running-and-saving-models/unsloth-inference.md): Learn how to run your finetuned model with Unsloth's faster inference.
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to fine-tune TTS & STT voice models with Unsloth.
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
- [Fine-tuning LLMs with Blackwell, RTX 50 series & Unsloth](/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth.md): Learn how to fine-tune LLMs on NVIDIA's Blackwell RTX 50 series and B200 GPUs with our step-by-step guide.
- [Multi-GPU Training with Unsloth](/basics/multi-gpu-training-with-unsloth.md): Learn how to fine-tune LLMs on multiple GPUs and parallelism with Unsloth.
- [Finetuning from Last Checkpoint](/basics/finetuning-from-last-checkpoint.md): Checkpointing allows you to save your finetuning progress so you can pause it and then continue.
- [Troubleshooting & FAQs](/basics/troubleshooting-and-faqs.md): Tips to solve issues, and frequently asked questions.
- [Chat Templates](/basics/chat-templates.md): Learn the fundamentals and customization options of chat templates, including Conversational, ChatML, ShareGPT, Alpaca formats, and more!
- [Quantization-Aware Training (QAT)](/basics/quantization-aware-training-qat.md): Quantize models to 4-bit with Unsloth and PyTorch to recover accuracy.
- [Unsloth Environment Flags](/basics/unsloth-environment-flags.md): Advanced flags which might be useful if you see breaking finetunes, or you want to turn stuff off.
- [Continued Pretraining](/basics/continued-pretraining.md): Also known as Continued Finetuning. Unsloth allows you to continually pretrain so a model can learn a new language.
- [Unsloth Benchmarks](/basics/unsloth-benchmarks.md): Unsloth recorded benchmarks on NVIDIA GPUs.
@@ -0,0 +1,436 @@
---
name: nnsight-remote-interpretability
description: Provides guidance for interpreting and manipulating neural network internals using nnsight with optional NDIF remote execution. Use when needing to run interpretability experiments on massive models (70B+) without local GPU resources, or when working with any PyTorch architecture.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [nnsight, NDIF, Remote Execution, Mechanistic Interpretability, Model Internals]
dependencies: [nnsight>=0.5.0, torch>=2.0.0]
---
# nnsight: Transparent Access to Neural Network Internals
nnsight (/ɛn.saɪt/) enables researchers to interpret and manipulate the internals of any PyTorch model, with the unique capability of running the same code locally on small models or remotely on massive models (70B+) via NDIF.
**GitHub**: [ndif-team/nnsight](https://github.com/ndif-team/nnsight) (730+ stars)
**Paper**: [NNsight and NDIF: Democratizing Access to Foundation Model Internals](https://arxiv.org/abs/2407.14561) (ICLR 2025)
## Key Value Proposition
**Write once, run anywhere**: The same interpretability code works on GPT-2 locally or Llama-3.1-405B remotely. Just toggle `remote=True`.
```python
# Local execution (small model)
with model.trace("Hello world"):
hidden = model.transformer.h[5].output[0].save()
# Remote execution (massive model) - same code!
with model.trace("Hello world", remote=True):
hidden = model.model.layers[40].output[0].save()
```
## When to Use nnsight
**Use nnsight when you need to:**
- Run interpretability experiments on models too large for local GPUs (70B, 405B)
- Work with any PyTorch architecture (transformers, Mamba, custom models)
- Perform multi-token generation interventions
- Share activations between different prompts
- Access full model internals without reimplementation
**Consider alternatives when:**
- You want consistent API across models → Use **TransformerLens**
- You need declarative, shareable interventions → Use **pyvene**
- You're training SAEs → Use **SAELens**
- You only work with small models locally → **TransformerLens** may be simpler
## Installation
```bash
# Basic installation
pip install nnsight
# For vLLM support
pip install "nnsight[vllm]"
```
For remote NDIF execution, sign up at [login.ndif.us](https://login.ndif.us) for an API key.
## Core Concepts
### LanguageModel Wrapper
```python
from nnsight import LanguageModel
# Load model (uses HuggingFace under the hood)
model = LanguageModel("openai-community/gpt2", device_map="auto")
# For larger models
model = LanguageModel("meta-llama/Llama-3.1-8B", device_map="auto")
```
### Tracing Context
The `trace` context manager enables deferred execution - operations are collected into a computation graph:
```python
from nnsight import LanguageModel
model = LanguageModel("gpt2", device_map="auto")
with model.trace("The Eiffel Tower is in") as tracer:
# Access any module's output
hidden_states = model.transformer.h[5].output[0].save()
# Access attention patterns
attn = model.transformer.h[5].attn.attn_dropout.input[0][0].save()
# Modify activations
model.transformer.h[8].output[0][:] = 0 # Zero out layer 8
# Get final output
logits = model.output.save()
# After context exits, access saved values
print(hidden_states.shape) # [batch, seq, hidden]
```
### Proxy Objects
Inside `trace`, module accesses return Proxy objects that record operations:
```python
with model.trace("Hello"):
# These are all Proxy objects - operations are deferred
h5_out = model.transformer.h[5].output[0] # Proxy
h5_mean = h5_out.mean(dim=-1) # Proxy
h5_saved = h5_mean.save() # Save for later access
```
## Workflow 1: Activation Analysis
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
prompt = "The capital of France is"
with model.trace(prompt) as tracer:
# 1. Collect activations from multiple layers
layer_outputs = []
for i in range(12): # GPT-2 has 12 layers
layer_out = model.transformer.h[i].output[0].save()
layer_outputs.append(layer_out)
# 2. Get attention patterns
attn_patterns = []
for i in range(12):
# Access attention weights (after softmax)
attn = model.transformer.h[i].attn.attn_dropout.input[0][0].save()
attn_patterns.append(attn)
# 3. Get final logits
logits = model.output.save()
# 4. Analyze outside context
for i, layer_out in enumerate(layer_outputs):
print(f"Layer {i} output shape: {layer_out.shape}")
print(f"Layer {i} norm: {layer_out.norm().item():.3f}")
# 5. Find top predictions
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
for token, prob in zip(top_tokens.indices, top_tokens.values):
print(f"{model.tokenizer.decode(token)}: {prob.item():.3f}")
```
### Checklist
- [ ] Load model with LanguageModel wrapper
- [ ] Use trace context for operations
- [ ] Call `.save()` on values you need after context
- [ ] Access saved values outside context
- [ ] Use `.shape`, `.norm()`, etc. for analysis
## Workflow 2: Activation Patching
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
clean_prompt = "The Eiffel Tower is in"
corrupted_prompt = "The Colosseum is in"
# 1. Get clean activations
with model.trace(clean_prompt) as tracer:
clean_hidden = model.transformer.h[8].output[0].save()
# 2. Patch clean into corrupted run
with model.trace(corrupted_prompt) as tracer:
# Replace layer 8 output with clean activations
model.transformer.h[8].output[0][:] = clean_hidden
patched_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 3. Compare predictions
paris_token = model.tokenizer.encode(" Paris")[0]
rome_token = model.tokenizer.encode(" Rome")[0]
patched_probs = torch.softmax(patched_logits[0, -1], dim=-1)
print(f"Paris prob: {patched_probs[paris_token].item():.3f}")
print(f"Rome prob: {patched_probs[rome_token].item():.3f}")
```
### Systematic Patching Sweep
```python
def patch_layer_position(layer, position, clean_cache, corrupted_prompt):
"""Patch single layer/position from clean to corrupted."""
with model.trace(corrupted_prompt) as tracer:
# Get current activation
current = model.transformer.h[layer].output[0]
# Patch only specific position
current[:, position, :] = clean_cache[layer][:, position, :]
logits = model.output.save()
return logits
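# Cache clean activations for every layer (patch_layer_position indexes clean_cache[layer])
with model.trace(clean_prompt) as tracer:
    clean_cache = [model.transformer.h[layer].output[0].save() for layer in range(12)]
# Position-wise patching requires both prompts to tokenize to the same length
seq_len = clean_cache[0].shape[1]
# Sweep over all layers and positions
results = torch.zeros(12, seq_len)
for layer in range(12):
    for pos in range(seq_len):
        logits = patch_layer_position(layer, pos, clean_cache, corrupted_prompt)
        results[layer, pos] = compute_metric(logits)  # compute_metric: your own scalar metric, e.g. a logit difference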
```
## Workflow 3: Remote Execution with NDIF
Run the same experiments on massive models without local GPUs.
### Step-by-Step
```python
from nnsight import LanguageModel
# 1. Load large model (will run remotely)
model = LanguageModel("meta-llama/Llama-3.1-70B")
# 2. Same code, just add remote=True
with model.trace("The meaning of life is", remote=True) as tracer:
# Access internals of 70B model!
layer_40_out = model.model.layers[40].output[0].save()
logits = model.output.save()
# 3. Results returned from NDIF
print(f"Layer 40 shape: {layer_40_out.shape}")
# 4. Generation with interventions (model.generate is itself a tracing context)
with model.generate(max_new_tokens=50, remote=True) as tracer:
    with tracer.invoke("What is 2+2?"):
        # Intervene during generation
        model.model.layers[20].output[0][:, -1, :] *= 1.5
        output = model.generator.output.save()
```
### NDIF Setup
1. Sign up at [login.ndif.us](https://login.ndif.us)
2. Get API key
3. Set environment variable or pass to nnsight:
```python
import os
os.environ["NDIF_API_KEY"] = "your_key"
# Or configure directly
from nnsight import CONFIG
CONFIG.set_default_api_key("your_key")
```
### Available Models on NDIF
- Llama-3.1-8B, 70B, 405B
- DeepSeek-R1 models
- Various open-weight models (check [ndif.us](https://ndif.us) for current list)
## Workflow 4: Cross-Prompt Activation Sharing
Share activations between different inputs in a single trace.
```python
from nnsight import LanguageModel
model = LanguageModel("gpt2", device_map="auto")
with model.trace() as tracer:
# First prompt
with tracer.invoke("The cat sat on the"):
cat_hidden = model.transformer.h[6].output[0].save()
# Second prompt - inject cat's activations
with tracer.invoke("The dog ran through the"):
# Replace with cat's activations at layer 6
model.transformer.h[6].output[0][:] = cat_hidden
dog_with_cat = model.output.save()
# The dog prompt now has cat's internal representations
```
## Workflow 5: Gradient-Based Analysis
Access gradients during backward pass.
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
with model.trace("The quick brown fox") as tracer:
# Save activations and enable gradient
hidden = model.transformer.h[5].output[0].save()
hidden.retain_grad()
logits = model.output.logits  # model.output is a ModelOutput; index the logits tensor
# Compute loss on specific token
target_token = model.tokenizer.encode(" jumps")[0]
loss = -logits[0, -1, target_token]
# Backward pass
loss.backward()
# Access gradients
grad = hidden.grad
print(f"Gradient shape: {grad.shape}")
print(f"Gradient norm: {grad.norm().item():.3f}")
```
**Note**: Gradient access is not supported with vLLM or remote execution.
## Common Issues & Solutions
### Issue: Module path differs between models
```python
# GPT-2 structure
model.transformer.h[5].output[0]
# LLaMA structure
model.model.layers[5].output[0]
# Solution: Check model structure
print(model._model) # See actual module names
```
### Issue: Forgetting to save
```python
# WRONG: Value not accessible outside trace
with model.trace("Hello"):
hidden = model.transformer.h[5].output[0] # Not saved!
print(hidden) # Error or wrong value
# RIGHT: Call .save()
with model.trace("Hello"):
hidden = model.transformer.h[5].output[0].save()
print(hidden) # Works!
```
### Issue: Remote timeout
```python
# For long operations, increase timeout
with model.trace("prompt", remote=True, timeout=300) as tracer:
# Long operation...
```
### Issue: Memory with many saved activations
```python
# Only save what you need
with model.trace("prompt"):
# Don't save everything
for i in range(100):
model.transformer.h[i].output[0].save() # Memory heavy!
# Better: save specific layers
key_layers = [0, 5, 11]
for i in key_layers:
model.transformer.h[i].output[0].save()
```
### Issue: vLLM gradient limitation
```python
# vLLM doesn't support gradients
# Use standard execution for gradient analysis
model = LanguageModel("gpt2", device_map="auto") # Not vLLM
```
## Key API Reference
| Method/Property | Purpose |
|-----------------|---------|
| `model.trace(prompt, remote=False)` | Start tracing context |
| `proxy.save()` | Save value for access after trace |
| `proxy[:]` | Slice/index proxy (assignment patches) |
| `tracer.invoke(prompt)` | Add prompt within trace |
| `model.generate(...)` | Generate with interventions |
| `model.output` | Final model output (for HuggingFace causal LMs a `ModelOutput`; use `model.output.logits` for the logits tensor) |
| `model._model` | Underlying HuggingFace model |
## Comparison with Other Tools
| Feature | nnsight | TransformerLens | pyvene |
|---------|---------|-----------------|--------|
| Any architecture | Yes | Transformers only | Yes |
| Remote execution | Yes (NDIF) | No | No |
| Consistent API | No | Yes | Yes |
| Deferred execution | Yes | No | No |
| HuggingFace native | Yes | Reimplemented | Yes |
| Shareable configs | No | No | Yes |
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for LanguageModel, tracing, proxy objects |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for local and remote interpretability |
## External Resources
### Tutorials
- [Getting Started](https://nnsight.net/start/)
- [Features Overview](https://nnsight.net/features/)
- [Remote Execution](https://nnsight.net/notebooks/features/remote_execution/)
- [Applied Tutorials](https://nnsight.net/applied_tutorials/)
### Official Documentation
- [Official Docs](https://nnsight.net/documentation/)
- [NDIF Info](https://ndif.us/)
- [Community Forum](https://discuss.ndif.us/)
### Papers
- [NNsight and NDIF Paper](https://arxiv.org/abs/2407.14561) - Fiotto-Kaufman et al. (ICLR 2025)
## Architecture Support
nnsight works with any PyTorch model:
- **Transformers**: GPT-2, LLaMA, Mistral, etc.
- **State Space Models**: Mamba
- **Vision Models**: ViT, CLIP
- **Custom architectures**: Any nn.Module
The key is knowing the module structure to access the right components.
@@ -0,0 +1,78 @@
# nnsight Reference Documentation
This directory contains comprehensive reference materials for nnsight.
## Contents
- [api.md](api.md) - Complete API reference for LanguageModel, tracing, and proxy objects
- [tutorials.md](tutorials.md) - Step-by-step tutorials for local and remote interpretability
## Quick Links
- **Official Documentation**: https://nnsight.net/
- **GitHub Repository**: https://github.com/ndif-team/nnsight
- **NDIF (Remote Execution)**: https://ndif.us/
- **Community Forum**: https://discuss.ndif.us/
- **Paper**: https://arxiv.org/abs/2407.14561 (ICLR 2025)
## Installation
```bash
# Basic installation
pip install nnsight
# For vLLM support
pip install "nnsight[vllm]"
```
## Basic Usage
```python
from nnsight import LanguageModel
# Load model
model = LanguageModel("openai-community/gpt2", device_map="auto")
# Trace and access internals
with model.trace("The Eiffel Tower is in") as tracer:
# Access layer output
hidden = model.transformer.h[5].output[0].save()
# Modify activations
model.transformer.h[8].output[0][:] *= 0.5
# Get final output
logits = model.output.save()
# Access saved values outside context
print(hidden.shape)
```
## Key Concepts
### Tracing
The `trace()` context enables deferred execution - operations are recorded and executed together.
### Proxy Objects
Inside trace, module accesses return Proxies. Call `.save()` to retrieve values after execution.
### Remote Execution (NDIF)
Run the same code on massive models (70B+) without local GPUs:
```python
# Same code, just add remote=True
with model.trace("Hello", remote=True):
hidden = model.model.layers[40].output[0].save()
```
## NDIF Setup
1. Sign up at https://login.ndif.us/
2. Get API key
3. Set environment variable: `export NDIF_API_KEY=your_key`
## Available Remote Models
- Llama-3.1-8B, 70B, 405B
- DeepSeek-R1 models
- More at https://ndif.us/
@@ -0,0 +1,344 @@
# nnsight API Reference
## LanguageModel
Main class for wrapping language models with intervention capabilities.
### Loading Models
```python
from nnsight import LanguageModel
# Basic loading
model = LanguageModel("openai-community/gpt2", device_map="auto")
# Larger models
model = LanguageModel("meta-llama/Llama-3.1-8B", device_map="auto")
# With a custom dtype (requires `import torch`)
model = LanguageModel(
"gpt2",
device_map="auto",
torch_dtype=torch.float16,
)
```
### Model Attributes
```python
# Access underlying HuggingFace model
model._model
# Access tokenizer
model.tokenizer
# Model config
model._model.config
```
---
## Tracing Context
The `trace()` method creates a context for deferred execution.
### Basic Tracing
```python
with model.trace("Hello world") as tracer:
# Operations are recorded, not executed immediately
hidden = model.transformer.h[5].output[0].save()
logits = model.output.save()
# After context, operations execute and saved values are available
print(hidden.shape)
```
### Tracing Parameters
```python
with model.trace(
prompt, # Input text or tokens
remote=False, # Use NDIF remote execution
validate=True, # Validate tensor shapes
scan=True, # Scan for shape info
) as tracer:
...
```
### Remote Execution
```python
# Same code works remotely
with model.trace("Hello", remote=True) as tracer:
hidden = model.transformer.h[5].output[0].save()
```
---
## Proxy Objects
Inside tracing context, accessing modules returns Proxy objects.
### Accessing Values
```python
with model.trace("Hello") as tracer:
# These are Proxy objects
layer_output = model.transformer.h[5].output[0]
attention = model.transformer.h[5].attn.output
# Operations create new Proxies
mean = layer_output.mean(dim=-1)
normed = layer_output / layer_output.norm()
```
### Saving Values
```python
with model.trace("Hello") as tracer:
# Must call .save() to access after context
hidden = model.transformer.h[5].output[0].save()
# Now hidden contains actual tensor
print(hidden.shape)
```
### Modifying Values
```python
with model.trace("Hello") as tracer:
# In-place modification
model.transformer.h[5].output[0][:] = 0
# Replace with computed value
model.transformer.h[5].output[0][:] = some_tensor
# Arithmetic modification
model.transformer.h[5].output[0][:] *= 0.5
model.transformer.h[5].output[0][:] += steering_vector
```
### Proxy Operations
```python
with model.trace("Hello") as tracer:
h = model.transformer.h[5].output[0]
# Indexing
first_token = h[:, 0, :]
last_token = h[:, -1, :]
# PyTorch operations
mean = h.mean(dim=-1)
norm = h.norm()
transposed = h.transpose(1, 2)
# Save results
mean.save()
```
---
## Module Access Patterns
### GPT-2 Structure
```python
with model.trace("Hello") as tracer:
# Embeddings
embed = model.transformer.wte.output.save()
pos_embed = model.transformer.wpe.output.save()
# Layer outputs
layer_out = model.transformer.h[5].output[0].save()
# Attention
attn_out = model.transformer.h[5].attn.output.save()
# MLP
mlp_out = model.transformer.h[5].mlp.output.save()
# Final output
logits = model.output.save()
```
### LLaMA Structure
```python
with model.trace("Hello") as tracer:
# Embeddings
embed = model.model.embed_tokens.output.save()
# Layer outputs
layer_out = model.model.layers[10].output[0].save()
# Attention
attn_out = model.model.layers[10].self_attn.output.save()
# MLP
mlp_out = model.model.layers[10].mlp.output.save()
# Final output
logits = model.output.save()
```
### Finding Module Names
```python
# Print model structure
print(model._model)
# Or iterate
for name, module in model._model.named_modules():
print(name)
```
---
## Multiple Prompts (invoke)
Process multiple prompts in a single trace.
### Basic Usage
```python
with model.trace() as tracer:
with tracer.invoke("First prompt"):
hidden1 = model.transformer.h[5].output[0].save()
with tracer.invoke("Second prompt"):
hidden2 = model.transformer.h[5].output[0].save()
```
### Cross-Prompt Intervention
```python
with model.trace() as tracer:
# Get activations from first prompt
with tracer.invoke("The cat sat on the"):
cat_hidden = model.transformer.h[6].output[0].save()
# Inject into second prompt
with tracer.invoke("The dog ran through the"):
model.transformer.h[6].output[0][:] = cat_hidden
output = model.output.save()
```
---
## Generation
Generate text with interventions.
### Basic Generation
```python
with model.generate(max_new_tokens=50) as tracer:
    with tracer.invoke("Once upon a time"):
        # Intervention during generation
        model.transformer.h[5].output[0][:] *= 1.2
        # Save the generated token ids
        output = model.generator.output.save()
print(model.tokenizer.decode(output[0]))
```
---
## Gradients
Access gradients for analysis (not supported with remote/vLLM).
```python
with model.trace("The quick brown fox") as tracer:
hidden = model.transformer.h[5].output[0].save()
hidden.retain_grad()
logits = model.output.logits  # model.output is a ModelOutput; index the logits tensor
target_token = model.tokenizer.encode(" jumps")[0]
loss = -logits[0, -1, target_token]
loss.backward()
# Access gradient
grad = hidden.grad
```
---
## NDIF Remote Execution
### Setup
```python
import os
os.environ["NDIF_API_KEY"] = "your_key"
# Or configure directly
from nnsight import CONFIG
CONFIG.set_default_api_key("your_key")
```
### Using Remote
```python
model = LanguageModel("meta-llama/Llama-3.1-70B")
with model.trace("Hello", remote=True) as tracer:
hidden = model.model.layers[40].output[0].save()
logits = model.output.save()
# Results returned from NDIF
print(hidden.shape)
```
### Sessions (Batching Requests)
```python
with model.session(remote=True) as session:
with model.trace("First prompt"):
h1 = model.model.layers[20].output[0].save()
with model.trace("Second prompt"):
h2 = model.model.layers[20].output[0].save()
# Both run in single NDIF request
```
---
## Utility Methods
### Early Stopping
```python
with model.trace("Hello") as tracer:
hidden = model.transformer.h[5].output[0].save()
tracer.stop() # Don't run remaining layers
```
### Validation
```python
# Validate shapes before execution
with model.trace("Hello", validate=True) as tracer:
hidden = model.transformer.h[5].output[0].save()
```
### Module Access Result
```python
with model.trace("Hello") as tracer:
# Access result of a method call
result = tracer.result
```
---
## Common Module Paths
| Model | Embeddings | Layers | Attention | MLP |
|-------|------------|--------|-----------|-----|
| GPT-2 | `transformer.wte` | `transformer.h[i]` | `transformer.h[i].attn` | `transformer.h[i].mlp` |
| LLaMA | `model.embed_tokens` | `model.layers[i]` | `model.layers[i].self_attn` | `model.layers[i].mlp` |
| Mistral | `model.embed_tokens` | `model.layers[i]` | `model.layers[i].self_attn` | `model.layers[i].mlp` |
@@ -0,0 +1,300 @@
# nnsight Tutorials
## Tutorial 1: Basic Activation Analysis
### Goal
Load a model, access internal activations, and analyze them.
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
# 1. Load model
model = LanguageModel("openai-community/gpt2", device_map="auto")
# 2. Trace and collect activations
prompt = "The capital of France is"
with model.trace(prompt) as tracer:
# Collect from multiple layers
activations = {}
for i in range(12): # GPT-2 has 12 layers
activations[i] = model.transformer.h[i].output[0].save()
# Get final logits
logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 3. Analyze (outside context)
print("Layer-wise activation norms:")
for layer, act in activations.items():
print(f" Layer {layer}: {act.norm().item():.2f}")
# 4. Check predictions
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
print("\nTop predictions:")
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
token_str = model.tokenizer.decode(token_id)
print(f" {token_str!r}: {prob.item():.3f}")
```
---
## Tutorial 2: Activation Patching
### Goal
Patch activations from one prompt into another to test causal relationships.
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
# 1. Get clean activations
with model.trace(clean_prompt) as tracer:
clean_hidden = model.transformer.h[8].output[0].save()
clean_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 2. Define metric
paris_token = model.tokenizer.encode(" Paris")[0]
rome_token = model.tokenizer.encode(" Rome")[0]
def logit_diff(logits):
return (logits[0, -1, paris_token] - logits[0, -1, rome_token]).item()
print(f"Clean logit diff: {logit_diff(clean_logits):.3f}")
# 3. Patch clean into corrupted
with model.trace(corrupted_prompt) as tracer:
# Replace layer 8 output with clean activations
model.transformer.h[8].output[0][:] = clean_hidden
patched_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
print(f"Patched logit diff: {logit_diff(patched_logits):.3f}")
# 4. Systematic patching sweep
results = torch.zeros(12) # 12 layers
for layer in range(12):
# Get clean activation for this layer
with model.trace(clean_prompt) as tracer:
clean_act = model.transformer.h[layer].output[0].save()
# Patch into corrupted
with model.trace(corrupted_prompt) as tracer:
model.transformer.h[layer].output[0][:] = clean_act
logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
results[layer] = logit_diff(logits)
print(f"Layer {layer}: {results[layer]:.3f}")
print(f"\nMost important layer: {results.argmax().item()}")
```
---
## Tutorial 3: Cross-Prompt Activation Sharing
### Goal
Transfer activations between different prompts in a single trace.
### Step-by-Step
```python
from nnsight import LanguageModel
model = LanguageModel("gpt2", device_map="auto")
with model.trace() as tracer:
# First prompt - get "cat" representations
with tracer.invoke("The cat sat on the mat"):
cat_hidden = model.transformer.h[6].output[0].save()
# Second prompt - inject "cat" into "dog"
with tracer.invoke("The dog ran through the park"):
# Replace with cat's activations
model.transformer.h[6].output[0][:] = cat_hidden
modified_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# The dog prompt now has cat's internal representations
print(f"Modified logits shape: {modified_logits.shape}")
```
---
## Tutorial 4: Remote Execution with NDIF
### Goal
Run the same interpretability code on massive models (70B+).
### Step-by-Step
```python
from nnsight import LanguageModel
import os
# 1. Setup API key
os.environ["NDIF_API_KEY"] = "your_key_here"
# 2. Load large model (runs remotely)
model = LanguageModel("meta-llama/Llama-3.1-70B")
# 3. Same code, just remote=True
prompt = "The meaning of life is"
with model.trace(prompt, remote=True) as tracer:
# Access layer 40 of 70B model!
hidden = model.model.layers[40].output[0].save()
logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 4. Results returned from NDIF
print(f"Hidden shape: {hidden.shape}")
print(f"Logits shape: {logits.shape}")
# 5. Check predictions
import torch
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
print("\nTop predictions from Llama-70B:")
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
print(f" {model.tokenizer.decode(token_id)!r}: {prob.item():.3f}")
```
### Batching with Sessions
```python
# Run multiple experiments in one NDIF request
with model.session(remote=True) as session:
with model.trace("What is 2+2?"):
math_hidden = model.model.layers[30].output[0].save()
with model.trace("The capital of France is"):
fact_hidden = model.model.layers[30].output[0].save()
# Compare representations
similarity = torch.cosine_similarity(
math_hidden.mean(dim=1),
fact_hidden.mean(dim=1),
dim=-1
)
print(f"Similarity: {similarity.item():.3f}")
```
---
## Tutorial 5: Steering with Activation Addition
### Goal
Add a steering vector to change model behavior.
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
# 1. Get contrasting activations
with model.trace("I love this movie, it's wonderful") as tracer:
positive_hidden = model.transformer.h[6].output[0].save()
with model.trace("I hate this movie, it's terrible") as tracer:
negative_hidden = model.transformer.h[6].output[0].save()
# 2. Compute steering direction
steering_vector = positive_hidden.mean(dim=1) - negative_hidden.mean(dim=1)
# 3. Generate without steering
test_prompt = "This restaurant is"
with model.trace(test_prompt) as tracer:
normal_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 4. Generate with steering
with model.trace(test_prompt) as tracer:
# Add steering at layer 6
model.transformer.h[6].output[0][:] += 3.0 * steering_vector
steered_logits = model.output.logits.save()  # model.output is a ModelOutput; save its logits tensor
# 5. Compare predictions
def top_prediction(logits):
token = logits[0, -1].argmax()
return model.tokenizer.decode(token)
print(f"Normal: {top_prediction(normal_logits)}")
print(f"Steered (positive): {top_prediction(steered_logits)}")
```
---
## Tutorial 6: Logit Lens
### Goal
See what the model "believes" at each layer.
### Step-by-Step
```python
from nnsight import LanguageModel
import torch
model = LanguageModel("gpt2", device_map="auto")
prompt = "The quick brown fox jumps over the lazy"
with model.trace(prompt) as tracer:
# Collect residual stream at each layer
residuals = []
for i in range(12):
resid = model.transformer.h[i].output[0].save()
residuals.append(resid)
# Access model's unembedding and final layernorm
W_U = model._model.lm_head.weight.T # [d_model, vocab]
ln_f = model._model.transformer.ln_f
print("Layer-by-layer predictions for final token:")
for i, resid in enumerate(residuals):
# Apply final layernorm
normed = ln_f(resid)
# Project to vocabulary
layer_logits = normed @ W_U
# Get prediction
probs = torch.softmax(layer_logits[0, -1], dim=-1)
top_token = probs.argmax()
top_prob = probs[top_token].item()
print(f"Layer {i}: {model.tokenizer.decode(top_token)!r} ({top_prob:.3f})")
```
---
## External Resources
### Official Resources
- [Getting Started](https://nnsight.net/start/)
- [Features Overview](https://nnsight.net/features/)
- [Documentation](https://nnsight.net/documentation/)
- [Tutorials](https://nnsight.net/tutorials/)
### NDIF Resources
- [NDIF Homepage](https://ndif.us/)
- [Available Models](https://ndif.us/models)
- [API Key Signup](https://login.ndif.us/)
### Paper
- [NNsight and NDIF](https://arxiv.org/abs/2407.14561) - ICLR 2025
### Community
- [Discussion Forum](https://discuss.ndif.us/)
- [GitHub Issues](https://github.com/ndif-team/nnsight/issues)
@@ -0,0 +1,473 @@
---
name: pyvene-interventions
description: Provides guidance for performing causal interventions on PyTorch models using pyvene's declarative intervention framework. Use when conducting causal tracing, activation patching, interchange intervention training, or testing causal hypotheses about model behavior.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Causal Intervention, pyvene, Activation Patching, Causal Tracing, Interpretability]
dependencies: [pyvene>=0.1.8, torch>=2.0.0, transformers>=4.30.0]
---
# pyvene: Causal Interventions for Neural Networks
pyvene is Stanford NLP's library for performing causal interventions on PyTorch models. It provides a declarative, dict-based framework for activation patching, causal tracing, and interchange intervention training - making intervention experiments reproducible and shareable.
**GitHub**: [stanfordnlp/pyvene](https://github.com/stanfordnlp/pyvene) (840+ stars)
**Paper**: [pyvene: A Library for Understanding and Improving PyTorch Models via Interventions](https://aclanthology.org/2024.naacl-demo.16) (NAACL 2024)
## When to Use pyvene
**Use pyvene when you need to:**
- Perform causal tracing (ROME-style localization)
- Run activation patching experiments
- Conduct interchange intervention training (IIT)
- Test causal hypotheses about model components
- Share/reproduce intervention experiments via HuggingFace
- Work with any PyTorch architecture (not just transformers)
**Consider alternatives when:**
- You need exploratory activation analysis → Use **TransformerLens**
- You want to train/analyze SAEs → Use **SAELens**
- You need remote execution on massive models → Use **nnsight**
- You want lower-level control → Use **nnsight**
## Installation
```bash
pip install pyvene
```
Standard import:
```python
import pyvene as pv
```
## Core Concepts
### IntervenableModel
The main class that wraps any PyTorch model with intervention capabilities:
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Define intervention configuration
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
# Create intervenable model
intervenable = pv.IntervenableModel(config, model)
```
### Intervention Types
| Type | Description | Use Case |
|------|-------------|----------|
| `VanillaIntervention` | Swap activations between runs | Activation patching |
| `AdditionIntervention` | Add activations to base run | Steering, ablation |
| `SubtractionIntervention` | Subtract activations | Ablation |
| `ZeroIntervention` | Zero out activations | Component knockout |
| `RotatedSpaceIntervention` | DAS trainable intervention | Causal discovery |
| `CollectIntervention` | Collect activations | Probing, analysis |
### Component Targets
```python
# Available components to intervene on
components = [
"block_input", # Input to transformer block
"block_output", # Output of transformer block
"mlp_input", # Input to MLP
"mlp_output", # Output of MLP
"mlp_activation", # MLP hidden activations
"attention_input", # Input to attention
"attention_output", # Output of attention
"attention_value_output", # Attention value vectors
"query_output", # Query vectors
"key_output", # Key vectors
"value_output", # Value vectors
"head_attention_value_output", # Per-head values
]
```
## Workflow 1: Causal Tracing (ROME-style)
Locate where factual associations are stored by corrupting inputs and restoring activations.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
# 1. Define clean and corrupted inputs
clean_prompt = "The Space Needle is in downtown"
corrupted_prompt = "The ##### ###### ## ## ########" # Noise
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
# 2. Get clean activations (source)
with torch.no_grad():
clean_outputs = model(**clean_tokens, output_hidden_states=True)
clean_states = clean_outputs.hidden_states
# 3. Define restoration intervention
def run_causal_trace(layer, position):
"""Restore clean activation at specific layer and position."""
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="block_output",
intervention_type=pv.VanillaIntervention,
unit="pos",
max_number_of_units=1,
)
]
)
intervenable = pv.IntervenableModel(config, model)
# Run with intervention
_, patched_outputs = intervenable(
base=corrupted_tokens,
sources=[clean_tokens],
unit_locations={"sources->base": ([[[position]]], [[[position]]])},
output_original_output=True,
)
# Return probability of correct token
probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
seattle_token = tokenizer.encode(" Seattle")[0]
return probs[seattle_token].item()
# 4. Sweep over layers and positions
n_layers = model.config.n_layer
seq_len = clean_tokens["input_ids"].shape[1]
results = torch.zeros(n_layers, seq_len)
for layer in range(n_layers):
for pos in range(seq_len):
results[layer, pos] = run_causal_trace(layer, pos)
# 5. Visualize (layer x position heatmap)
# High values indicate causal importance
```
### Checklist
- [ ] Prepare clean prompt with target factual association
- [ ] Create corrupted version (noise or counterfactual)
- [ ] Define intervention config for each (layer, position)
- [ ] Run patching sweep
- [ ] Identify causal hotspots in heatmap
## Workflow 2: Activation Patching for Circuit Analysis
Test which components are necessary for a specific behavior.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# IOI task setup
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
john_token = tokenizer.encode(" John")[0]
mary_token = tokenizer.encode(" Mary")[0]
def logit_diff(logits):
"""IO - S logit difference."""
return logits[0, -1, john_token] - logits[0, -1, mary_token]
# Patch attention output at each layer
def patch_attention(layer):
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="attention_output",
intervention_type=pv.VanillaIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
_, patched_outputs = intervenable(
base=corrupted_tokens,
sources=[clean_tokens],
)
return logit_diff(patched_outputs.logits).item()
# Find which layers matter
results = []
for layer in range(model.config.n_layer):
diff = patch_attention(layer)
results.append(diff)
print(f"Layer {layer}: logit diff = {diff:.3f}")
```
## Workflow 3: Interchange Intervention Training (IIT)
Train interventions to discover causal structure.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2")
# 1. Define trainable intervention
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=6,
component="block_output",
intervention_type=pv.RotatedSpaceIntervention, # Trainable
low_rank_dimension=64, # Learn 64-dim subspace
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 2. Set up training
optimizer = torch.optim.Adam(
intervenable.get_trainable_parameters(),
lr=1e-4
)
# 3. Training loop (simplified)
for base_input, source_input, target_output in dataloader:
optimizer.zero_grad()
_, outputs = intervenable(
base=base_input,
sources=[source_input],
)
loss = criterion(outputs.logits, target_output)
loss.backward()
optimizer.step()
# 4. Analyze learned intervention
# The rotation matrix reveals causal subspace
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer
```
### DAS (Distributed Alignment Search)
```python
# Low-rank rotation finds interpretable subspaces
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8,
component="block_output",
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=1, # Find 1D causal direction
)
]
)
```
## Workflow 4: Model Steering (Honest LLaMA)
Steer model behavior during generation.
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
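# The steering intervention below was derived for the chat model, so load the chat checkpoint
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")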
# Load pre-trained steering intervention
intervenable = pv.IntervenableModel.load(
"zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
model=model,
)
# Generate with steering
prompt = "Is the earth flat?"
inputs = tokenizer(prompt, return_tensors="pt")
# Intervention applied during generation
# generate() returns (original_output, intervened_output); keep the intervened one
_, outputs = intervenable.generate(
inputs,
max_new_tokens=100,
do_sample=False,
)
print(tokenizer.decode(outputs[0]))
```
## Saving and Sharing Interventions
```python
# Save locally
intervenable.save("./my_intervention")
# Load from local
intervenable = pv.IntervenableModel.load(
"./my_intervention",
model=model,
)
# Share on HuggingFace (save() uploads to the Hub when save_to_hf_hub=True)
intervenable.save("./my_intervention", save_to_hf_hub=True, hf_repo_name="username/my-intervention")
# Load from HuggingFace
intervenable = pv.IntervenableModel.load(
"username/my-intervention",
model=model,
)
```
## Common Issues & Solutions
### Issue: Wrong intervention location
```python
# WRONG: Incorrect component name
config = pv.RepresentationConfig(
component="mlp", # Not valid!
)
# RIGHT: Use exact component name
config = pv.RepresentationConfig(
component="mlp_output", # Valid
)
```
### Issue: Dimension mismatch
```python
# Ensure source and base have compatible shapes
# For position-specific interventions:
config = pv.RepresentationConfig(
unit="pos",
max_number_of_units=1, # Intervene on single position
)
# Specify locations explicitly
intervenable(
base=base_tokens,
sources=[source_tokens],
unit_locations={"sources->base": ([[[5]]], [[[5]]])}, # Position 5
)
```
### Issue: Memory with large models
```python
# Use gradient checkpointing
model.gradient_checkpointing_enable()
# Or intervene on fewer components
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8, # Single layer instead of all
component="block_output",
)
]
)
```
### Issue: LoRA integration
```python
# pyvene v0.1.8+ supports LoRAs as interventions
config = pv.RepresentationConfig(
intervention_type=pv.LoRAIntervention,
low_rank_dimension=16,
)
```
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `IntervenableModel` | Main wrapper for interventions |
| `IntervenableConfig` | Configuration container |
| `RepresentationConfig` | Single intervention specification |
| `VanillaIntervention` | Activation swapping |
| `RotatedSpaceIntervention` | Trainable DAS intervention |
| `CollectIntervention` | Activation collection |
## Supported Models
pyvene works with any PyTorch model. Tested on:
- GPT-2 (all sizes)
- LLaMA / LLaMA-2
- Pythia
- Mistral / Mixtral
- OPT
- BLIP (vision-language)
- ESM (protein models)
- Mamba (state space)
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for IntervenableModel, intervention types, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for causal tracing, activation patching, DAS |
## External Resources
### Tutorials
- [pyvene 101](https://stanfordnlp.github.io/pyvene/tutorials/pyvene_101.html)
- [Causal Tracing Tutorial](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/Causal_Tracing.html)
- [IOI Circuit Replication](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/IOI_Replication.html)
- [DAS Introduction](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/DAS_Main_Introduction.html)
### Papers
- [Locating and Editing Factual Associations in GPT](https://arxiv.org/abs/2202.05262) - Meng et al. (2022)
- [Inference-Time Intervention](https://arxiv.org/abs/2306.03341) - Li et al. (2023)
- [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) - Wang et al. (2022)
### Official Documentation
- [Official Docs](https://stanfordnlp.github.io/pyvene/)
- [API Reference](https://stanfordnlp.github.io/pyvene/api/)
## Comparison with Other Tools
| Feature | pyvene | TransformerLens | nnsight |
|---------|--------|-----------------|---------|
| Declarative config | Yes | No | No |
| HuggingFace sharing | Yes | No | No |
| Trainable interventions | Yes | Limited | Yes |
| Any PyTorch model | Yes | Transformers only | Yes |
| Remote execution | No | No | Yes (NDIF) |
@@ -0,0 +1,73 @@
# pyvene Reference Documentation
This directory contains comprehensive reference materials for pyvene.
## Contents
- [api.md](api.md) - Complete API reference for IntervenableModel, intervention types, and configurations
- [tutorials.md](tutorials.md) - Step-by-step tutorials for causal tracing, activation patching, and trainable interventions
## Quick Links
- **Official Documentation**: https://stanfordnlp.github.io/pyvene/
- **GitHub Repository**: https://github.com/stanfordnlp/pyvene
- **Paper**: https://arxiv.org/abs/2403.07809 (NAACL 2024)
## Installation
```bash
pip install pyvene
```
## Basic Usage
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Define intervention
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
# Create intervenable model
intervenable = pv.IntervenableModel(config, model)
# Run intervention (swap activations from source to base)
base_inputs = tokenizer("The cat sat on the", return_tensors="pt")
source_inputs = tokenizer("The dog ran through the", return_tensors="pt")
_, outputs = intervenable(
base=base_inputs,
sources=[source_inputs],
)
```
## Key Concepts
### Intervention Types
- **VanillaIntervention**: Swap activations between runs
- **AdditionIntervention**: Add source to base activations
- **ZeroIntervention**: Zero out activations (ablation)
- **CollectIntervention**: Collect activations without modifying
- **RotatedSpaceIntervention**: Trainable intervention for causal discovery
### Components
Target specific parts of the model:
- `block_input`, `block_output`
- `mlp_input`, `mlp_output`, `mlp_activation`
- `attention_input`, `attention_output`
- `query_output`, `key_output`, `value_output`
### HuggingFace Integration
Save and load interventions via HuggingFace Hub for reproducibility.
@@ -0,0 +1,383 @@
# pyvene API Reference
## IntervenableModel
The core class that wraps PyTorch models for intervention.
### Basic Usage
```python
import pyvene as pv
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
```
### Forward Pass
```python
# Basic intervention (the first return value is None unless output_original_output=True)
_, intervened_output = intervenable(
base=base_inputs,
sources=[source_inputs],
)
# With unit locations (position-specific)
_, outputs = intervenable(
base=base_inputs,
sources=[source_inputs],
unit_locations={"sources->base": ([[[5]]], [[[5]]])}, # Position 5
)
# Return original output too
original, intervened = intervenable(
base=base_inputs,
sources=[source_inputs],
output_original_output=True,
)
```
### Generation
```python
# Generate with interventions; generate() returns (original_output, intervened_output)
_, outputs = intervenable.generate(
base_inputs,
sources=[source_inputs],
max_new_tokens=50,
do_sample=False,
)
```
### Saving and Loading
```python
# Save locally
intervenable.save("./my_intervention")
# Load
intervenable = pv.IntervenableModel.load("./my_intervention", model=model)
# Save to HuggingFace (save() uploads to the Hub when save_to_hf_hub=True)
intervenable.save("./my_intervention", save_to_hf_hub=True, hf_repo_name="username/my-intervention")
# Load from HuggingFace
intervenable = pv.IntervenableModel.load(
"username/my-intervention",
model=model
)
```
### Getting Trainable Parameters
```python
# For trainable interventions
params = intervenable.get_trainable_parameters()
optimizer = torch.optim.Adam(params, lr=1e-4)
```
---
## IntervenableConfig
Configuration container for interventions.
### Basic Config
```python
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(...)
]
)
```
### Multiple Interventions
```python
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(layer=3, component="block_output", ...),
pv.RepresentationConfig(layer=5, component="mlp_output", ...),
pv.RepresentationConfig(layer=7, component="attention_output", ...),
]
)
```
---
## RepresentationConfig
Specifies a single intervention target.
### Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `layer` | int | Layer index |
| `component` | str | Component to intervene on |
| `intervention_type` | type | Intervention class |
| `unit` | str | Intervention unit ("pos", "h", etc.) |
| `max_number_of_units` | int | Max units to intervene |
| `low_rank_dimension` | int | For trainable interventions |
| `subspace_partition` | list | Dimension ranges |
### Components
| Component | Description |
|-----------|-------------|
| `block_input` | Input to transformer block |
| `block_output` | Output of transformer block |
| `mlp_input` | Input to MLP |
| `mlp_output` | Output of MLP |
| `mlp_activation` | MLP hidden activations |
| `attention_input` | Input to attention |
| `attention_output` | Output of attention |
| `attention_value_output` | Attention values |
| `query_output` | Query vectors |
| `key_output` | Key vectors |
| `value_output` | Value vectors |
| `head_attention_value_output` | Per-head values |
### Example Configs
```python
# Position-specific intervention
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.VanillaIntervention,
unit="pos",
max_number_of_units=1,
)
# Trainable low-rank intervention
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=64,
)
# Subspace intervention
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.VanillaIntervention,
subspace_partition=[[0, 256], [256, 512]], # First 512 dims split
)
```
---
## Intervention Types
### Basic Interventions
#### VanillaIntervention
Replaces base activations with source activations.
```python
pv.RepresentationConfig(
intervention_type=pv.VanillaIntervention,
...
)
```
#### AdditionIntervention
Adds source activations to base.
```python
pv.RepresentationConfig(
intervention_type=pv.AdditionIntervention,
...
)
```
#### SubtractionIntervention
Subtracts source from base.
```python
pv.RepresentationConfig(
intervention_type=pv.SubtractionIntervention,
...
)
```
#### ZeroIntervention
Sets activations to zero (ablation).
```python
pv.RepresentationConfig(
intervention_type=pv.ZeroIntervention,
...
)
```
#### CollectIntervention
Collects activations without modification.
```python
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.CollectIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
_, collected = intervenable(base=inputs)
# collected contains the activations
```
### Trainable Interventions
#### RotatedSpaceIntervention
Full-rank trainable rotation.
```python
pv.RepresentationConfig(
intervention_type=pv.RotatedSpaceIntervention,
...
)
```
#### LowRankRotatedSpaceIntervention
Low-rank trainable intervention (DAS).
```python
pv.RepresentationConfig(
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=64,
...
)
```
#### BoundlessRotatedSpaceIntervention
Boundless DAS variant.
```python
pv.RepresentationConfig(
intervention_type=pv.BoundlessRotatedSpaceIntervention,
...
)
```
#### SigmoidMaskIntervention
Learnable binary mask.
```python
pv.RepresentationConfig(
intervention_type=pv.SigmoidMaskIntervention,
...
)
```
---
## Unit Locations
Specify exactly where to intervene.
### Format
```python
unit_locations = {
"sources->base": (source_locations, base_locations)
}
```
### Examples
```python
# Single position
unit_locations = {"sources->base": ([[[5]]], [[[5]]])}
# Multiple positions
unit_locations = {"sources->base": ([[[3, 5, 7]]], [[[3, 5, 7]]])}
# Different source and base positions
unit_locations = {"sources->base": ([[[5]]], [[[10]]])}
```
---
## Supported Models
pyvene works with any PyTorch model. Officially tested:
| Family | Models |
|--------|--------|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-2-7b, llama-2-13b |
| Pythia | pythia-70m to pythia-12b |
| Mistral | mistral-7b, mixtral-8x7b |
| Gemma | gemma-2b, gemma-7b |
| Vision | BLIP, LLaVA |
| Other | OPT, Phi, Qwen, ESM, Mamba |
---
## Quick Reference: Common Patterns
### Activation Patching
```python
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
```
### Causal Tracing (ROME-style)
```python
config = pv.IntervenableConfig(
representations=[
# First corrupt with noise
pv.RepresentationConfig(
layer=0,
component="block_input",
intervention_type=pv.NoiseIntervention,
),
# Then restore at target layer
pv.RepresentationConfig(
layer=target_layer,
component="block_output",
intervention_type=pv.VanillaIntervention,
),
]
)
```
### DAS (Distributed Alignment Search)
```python
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="block_output",
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=1, # Find 1D causal direction
)
]
)
```
@@ -0,0 +1,376 @@
# pyvene Tutorials
## Tutorial 1: Basic Activation Patching
### Goal
Swap activations between two prompts to test causal relationships.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# 1. Load model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 2. Prepare inputs
base_prompt = "The Colosseum is in the city of"
source_prompt = "The Eiffel Tower is in the city of"
base_inputs = tokenizer(base_prompt, return_tensors="pt")
source_inputs = tokenizer(source_prompt, return_tensors="pt")
# 3. Define intervention (patch layer 8)
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 4. Run intervention
_, patched_outputs = intervenable(
base=base_inputs,
sources=[source_inputs],
)
# 5. Check predictions
patched_logits = patched_outputs.logits
probs = torch.softmax(patched_logits[0, -1], dim=-1)
rome_token = tokenizer.encode(" Rome")[0]
paris_token = tokenizer.encode(" Paris")[0]
print(f"P(Rome): {probs[rome_token].item():.4f}")
print(f"P(Paris): {probs[paris_token].item():.4f}")
```
---
## Tutorial 2: Causal Tracing (ROME-style)
### Goal
Locate where factual associations are stored by corrupting inputs and restoring activations.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
# 1. Define prompts
clean_prompt = "The Space Needle is in downtown"
# We'll corrupt by adding noise to embeddings
clean_inputs = tokenizer(clean_prompt, return_tensors="pt")
seattle_token = tokenizer.encode(" Seattle")[0]
# 2. Get clean baseline
with torch.no_grad():
clean_outputs = model(**clean_inputs)
clean_prob = torch.softmax(clean_outputs.logits[0, -1], dim=-1)[seattle_token].item()
print(f"Clean P(Seattle): {clean_prob:.4f}")
# 3. Sweep over layers - corrupt input, restore at each layer
results = []
for restore_layer in range(model.config.n_layer):
# Config: add noise at input, restore at target layer
config = pv.IntervenableConfig(
representations=[
# Noise intervention at embedding
pv.RepresentationConfig(
layer=0,
component="block_input",
intervention_type=pv.NoiseIntervention,
),
# Restore clean at target layer
pv.RepresentationConfig(
layer=restore_layer,
component="block_output",
intervention_type=pv.VanillaIntervention,
),
]
)
intervenable = pv.IntervenableModel(config, model)
# Source is clean (for restoration), base gets noise
_, outputs = intervenable(
base=clean_inputs,
sources=[clean_inputs], # Restore from clean
)
prob = torch.softmax(outputs.logits[0, -1], dim=-1)[seattle_token].item()
results.append(prob)
print(f"Restore at layer {restore_layer}: P(Seattle) = {prob:.4f}")
# 4. Find critical layers (where restoration helps most)
import numpy as np
results = np.array(results)
critical_layers = np.argsort(results)[-5:]
print(f"\nMost critical layers: {critical_layers}")
```
---
## Tutorial 3: Trainable Interventions (DAS)
### Goal
Learn a low-rank intervention that achieves a target counterfactual behavior.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 1. Define trainable intervention
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=6,
component="block_output",
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=64, # Learn 64-dim subspace
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 2. Setup optimizer
optimizer = torch.optim.Adam(
intervenable.get_trainable_parameters(),
lr=1e-3
)
# 3. Training data (simplified example)
# Goal: Make model predict "Paris" instead of "Rome"
base_prompt = "The capital of Italy is"
target_token = tokenizer.encode(" Paris")[0]
base_inputs = tokenizer(base_prompt, return_tensors="pt")
# 4. Training loop
for step in range(100):
optimizer.zero_grad()
_, outputs = intervenable(
base=base_inputs,
sources=[base_inputs], # Self-intervention
)
# Loss: maximize probability of target token
logits = outputs.logits[0, -1]
loss = -torch.log_softmax(logits, dim=-1)[target_token]
loss.backward()
optimizer.step()
if step % 20 == 0:
prob = torch.softmax(logits.detach(), dim=-1)[target_token].item()
print(f"Step {step}: loss={loss.item():.4f}, P(Paris)={prob:.4f}")
# 5. Analyze learned rotation
rotation = intervenable.interventions["layer.6.comp.block_output.unit.pos.nunit.1#0"][0]
print(f"Learned rotation shape: {rotation.rotate_layer.weight.shape}")
```
---
## Tutorial 4: Position-Specific Intervention
### Goal
Intervene at specific token positions only.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 1. Setup
base_prompt = "John and Mary went to the store"
source_prompt = "Alice and Bob went to the store"
base_inputs = tokenizer(base_prompt, return_tensors="pt")
source_inputs = tokenizer(source_prompt, return_tensors="pt")
# 2. Position-specific config
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.VanillaIntervention,
unit="pos",
max_number_of_units=1, # Single position
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 3. Intervene at position 0 only (first name)
_, outputs = intervenable(
base=base_inputs,
sources=[source_inputs],
unit_locations={"sources->base": ([[[0]]], [[[0]]])},
)
# 4. Intervene at multiple positions
_, outputs = intervenable(
base=base_inputs,
sources=[source_inputs],
unit_locations={"sources->base": ([[[0, 2]]], [[[0, 2]]])},
)
```
---
## Tutorial 5: Collecting Activations
### Goal
Extract activations without modifying them.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 1. Config with CollectIntervention
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=5,
component="block_output",
intervention_type=pv.CollectIntervention,
),
pv.RepresentationConfig(
layer=10,
component="attention_output",
intervention_type=pv.CollectIntervention,
),
]
)
intervenable = pv.IntervenableModel(config, model)
# 2. Run and collect
inputs = tokenizer("Hello world", return_tensors="pt")
_, collected = intervenable(base=inputs)
# 3. Access collected activations
layer5_output = collected[0]
layer10_attn = collected[1]
print(f"Layer 5 block output shape: {layer5_output.shape}")
print(f"Layer 10 attention output shape: {layer10_attn.shape}")
```
---
## Tutorial 6: Generation with Interventions
### Goal
Apply interventions during text generation.
### Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 1. Get steering direction (happy vs sad)
happy_inputs = tokenizer("I am very happy and", return_tensors="pt")
sad_inputs = tokenizer("I am very sad and", return_tensors="pt")
# Collect activations
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=6,
component="mlp_output",
intervention_type=pv.CollectIntervention,
)
]
)
collector = pv.IntervenableModel(config, model)
_, happy_acts = collector(base=happy_inputs)
_, sad_acts = collector(base=sad_inputs)
steering_direction = happy_acts[0].mean(dim=1) - sad_acts[0].mean(dim=1)
# 2. Config for steering during generation
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=6,
component="mlp_output",
intervention_type=pv.AdditionIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 3. Generate with steering
prompt = "Today I feel"
inputs = tokenizer(prompt, return_tensors="pt")
# Create source with steering direction
# (This is simplified - actual implementation varies)
# generate() returns (original_output, intervened_output); keep the intervened one
_, output = intervenable.generate(
inputs,
max_new_tokens=20,
do_sample=True,
temperature=0.7,
)
print(tokenizer.decode(output[0]))
```
---
## External Resources
### Official Tutorials
- [pyvene 101](https://stanfordnlp.github.io/pyvene/tutorials/pyvene_101.html)
- [Causal Tracing](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/Causal_Tracing.html)
- [DAS Introduction](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/DAS_Main_Introduction.html)
- [IOI Replication](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/IOI_Replication.html)
### Papers
- [pyvene Paper](https://arxiv.org/abs/2403.07809) - NAACL 2024
- [ROME](https://arxiv.org/abs/2202.05262) - Meng et al. (2022)
- [Inference-Time Intervention](https://arxiv.org/abs/2306.03341) - Li et al. (2023)
@@ -0,0 +1,386 @@
---
name: sparse-autoencoder-training
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
---
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. It builds on Anthropic's groundbreaking research on monosemanticity.
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
## The Problem: Polysemanticity & Superposition
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
## When to Use SAELens
**Use SAELens when you need to:**
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)
**Consider alternatives when:**
- You need basic activation analysis → Use **TransformerLens** directly
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
- You need production steering → Consider direct activation engineering
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Core Concepts
### What SAEs Learn
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
    (d_model)        ↓     (d_sae >> d_model)     ↓          (d_model)
                  sparsity                  reconstruction
                  penalty                       loss
```
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Validation (Anthropic Research)
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures
## Workflow 1: Loading and Analyzing Pre-trained SAEs
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
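# sae-lens>=6: the (sae, cfg_dict, sparsity) tuple comes from from_pretrained_with_cfg_and_sparsity
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(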
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]
# 3. Encode to SAE features
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
# 4. Find top features for each position
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"Token '{token}': features {top_features.indices.tolist()}")
# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```
### Available Pre-trained SAEs
| Release | Model | Layers |
|---------|-------|--------|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |
### Checklist
- [ ] Load model with TransformerLens
- [ ] Load matching SAE for target layer
- [ ] Encode activations to sparse features
- [ ] Identify top-activating features per token
- [ ] Validate reconstruction quality
## Workflow 2: Training a Custom SAE
### Step-by-Step
```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # Model dimension
# SAE architecture
architecture="standard", # or "gated", "topk"
d_sae=768 * 8, # Expansion factor of 8
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5, # Sparsity penalty
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,
# Data
dataset_path="monology/pile-uncopyrighted",
context_size=128,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
)
# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```
### Key Hyperparameters
| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
### Evaluation Metrics
| Metric | Target | Meaning |
|--------|--------|---------|
| **L0** | 50-200 | Average active features per token |
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
| **Dead Features** | <5% | Features that never activate |
| **Explained Variance** | >90% | Reconstruction quality |
### Checklist
- [ ] Choose target layer and hook point
- [ ] Set expansion factor (d_sae = 4-16× d_model)
- [ ] Tune L1 coefficient for desired sparsity
- [ ] Enable L1 warm-up to prevent dead features
- [ ] Monitor metrics during training (W&B)
- [ ] Validate L0 and CE loss recovery
- [ ] Check dead feature ratio
## Workflow 3: Feature Analysis and Steering
### Analyzing Individual Features
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
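# sae-lens>=6: the (sae, cfg_dict, sparsity) tuple comes from from_pretrained_with_cfg_and_sparsity
sae, _, _ = SAE.from_pretrained_with_cfg_and_sparsity(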
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Find what activates a specific feature
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]
for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")
```
### Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
"""Add SAE feature direction to residual stream."""
tokens = model.to_tokens(prompt)
# Get feature direction from decoder
feature_direction = sae.W_dec[feature_idx] # [d_model]
def steering_hook(activation, hook):
# Add scaled feature direction at all positions
activation += strength * feature_direction
return activation
# Generate with steering
output = model.generate(
tokens,
max_new_tokens=50,
fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]
)
return model.to_string(output[0])
```
### Feature Attribution
```python
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]
# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
```
## Common Issues & Solutions
### Issue: High dead feature ratio
```python
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # Bad!
)
# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # Gradually increase
use_ghost_grads=True, # Revive dead features
)
```
### Issue: Poor reconstruction (low CE recovery)
```python
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # Lower = better reconstruction
d_sae=768 * 16, # More capacity
)
```
### Issue: Features not interpretable
```python
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # Higher = sparser, more interpretable
)
# Or use TopK architecture (sparsity is fixed by k instead of tuned via L1)
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",            # keep in sync with the architecture
    activation_fn_kwargs={"k": 50},  # Exactly 50 active features
)
```
### Issue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # Reduce batch size
store_batch_size_prompts=4, # Fewer prompts in buffer
n_batches_in_buffer=8, # Smaller activation buffer
)
```
## Integration with Neuronpedia
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
```python
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
```
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
## External Resources
### Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
### Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
### Official Documentation
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
- [Neuronpedia](https://neuronpedia.org) - Feature browser
## SAE Architectures
| Architecture | Description | Use Case |
|--------------|-------------|----------|
| **Standard** | ReLU + L1 penalty | General purpose |
| **Gated** | Learned gating mechanism | Better sparsity control |
| **TopK** | Exactly K active features | Consistent sparsity |
```python
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)
```
@@ -0,0 +1,70 @@
# SAELens Reference Documentation
This directory contains comprehensive reference materials for SAELens.
## Contents
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
- Key research papers on sparse autoencoders are listed under [Key Papers in tutorials.md](tutorials.md#key-papers)
## Quick Links
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
- **HuggingFace SAEs**: Search for tag `saelens`
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Basic Usage
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Encode activations to sparse features
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations) # Sparse feature activations
reconstructed = sae.decode(features) # Reconstructed activations
```
## Key Concepts
### Sparse Autoencoders
SAEs decompose dense neural activations into sparse, interpretable features:
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
- **ReLU/TopK**: Enforces sparsity
- **Decoder**: Reconstructs original activations
### Training Loss
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Metrics
- **L0**: Average number of active features (target: 50-200)
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
- **Dead Features**: Features that never activate (target: <5%)
## Available Pre-trained SAEs
| Release | Model | Description |
|---------|-------|-------------|
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
| Various | Search HuggingFace | Community-trained SAEs |
@@ -0,0 +1,333 @@
# SAELens API Reference
## SAE Class
The core class representing a Sparse Autoencoder.
### Loading Pre-trained SAEs
```python
from sae_lens import SAE
# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="username/repo-name",
sae_id="path/to/sae",
device="cuda"
)
# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```
### SAE Attributes
| Attribute | Shape | Description |
|-----------|-------|-------------|
| `W_enc` | [d_in, d_sae] | Encoder weights |
| `W_dec` | [d_sae, d_in] | Decoder weights |
| `b_enc` | [d_sae] | Encoder bias |
| `b_dec` | [d_in] | Decoder bias |
| `cfg` | SAEConfig | Configuration object |
### Core Methods
#### encode()
```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```
#### decode()
```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```
#### forward()
```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```
#### save_model()
```python
sae.save_model("/path/to/save")
```
---
## SAEConfig
Configuration class for SAE architecture and training context.
### Key Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `d_in` | int | Input dimension (model's d_model) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |
### Accessing Config
```python
print(sae.cfg.d_in) # 768 for GPT-2 small
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
```
---
## LanguageModelSAERunnerConfig
Comprehensive configuration for training SAEs.
### Example Configuration
```python
from sae_lens import LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(
# Model and hook
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768,
# SAE architecture
architecture="standard", # "standard", "gated", "jumprelu", "topk"
d_sae=768 * 8, # Expansion factor
activation_fn="relu",
# Training hyperparameters
lr=4e-4,
l1_coefficient=8e-5,
lp_norm=1.0,
lr_scheduler_name="constant",
lr_warm_up_steps=500,
# Sparsity control
l1_warm_up_steps=1000,
use_ghost_grads=True,
feature_sampling_window=1000,
dead_feature_window=5000,
dead_feature_threshold=1e-8,
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Batch sizes
train_batch_size_tokens=4096,
store_batch_size_prompts=16,
n_batches_in_buffer=64,
# Training duration
training_tokens=100_000_000,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
wandb_log_frequency=100,
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
# Hardware
device="cuda",
dtype="float32",
)
```
### Key Parameters Explained
#### Architecture Parameters
| Parameter | Description |
|-----------|-------------|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
#### Sparsity Parameters
| Parameter | Description |
|-----------|-------------|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold for "dead" |
| `dead_feature_window` | Steps to check for dead features |
#### Learning Rate Parameters
| Parameter | Description |
|-----------|-------------|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | LR warmup steps |
| `lr_decay_steps` | Steps for LR decay |
---
## SAETrainingRunner
Main class for executing training.
### Basic Training
```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```
### Accessing Training Metrics
```python
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
```
---
## ActivationsStore
Manages activation collection and batching.
### Basic Usage
```python
from sae_lens import ActivationsStore
# Build a store that streams the SAE's training dataset through `model`
store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device="cuda",
)
# get_batch_tokens() returns token IDs, not activations
batch_tokens = store.get_batch_tokens()
# Run the model on those tokens to get activations at the SAE's hook point
activations = store.get_activations(batch_tokens)
```
---
## HookedSAETransformer
Integration of SAEs with TransformerLens models.
### Basic Usage
```python
from sae_lens import HookedSAETransformer
# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)
# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])
# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
```
---
## SAE Architectures
### Standard (ReLU + L1)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="standard",
activation_fn="relu",
l1_coefficient=8e-5,
)
```
### Gated
```python
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
)
```
### TopK
```python
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### JumpReLU (State-of-the-art)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
)
```
---
## Utility Functions
### Upload to HuggingFace
```python
from sae_lens import upload_saes_to_huggingface
# Authenticate beforehand with `huggingface-cli login` or the HF_TOKEN
# environment variable; never hard-code access tokens in source files.
upload_saes_to_huggingface(
    saes_dict={"blocks.8.hook_resid_pre": sae},  # sae_id -> SAE (or saved path)
    hf_repo_id="username/my-saes",
)
```
### Neuronpedia Integration
```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```
@@ -0,0 +1,318 @@
# SAELens Tutorials
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
### Goal
Load a pre-trained SAE and analyze which features activate on specific inputs.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
# 1. Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
print(f"SAE input dim: {sae.cfg.d_in}")
print(f"SAE hidden dim: {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
# 2. Get model activations
prompt = "The capital of France is Paris"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [1, seq_len, 768]
# 3. Encode to SAE features
features = sae.encode(activations) # [1, seq_len, d_sae]
# 4. Analyze sparsity
active_per_token = (features > 0).sum(dim=-1)
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
# 5. Find top features for each token
str_tokens = model.to_str_tokens(prompt)
for pos in range(len(str_tokens)):
top_features = features[0, pos].topk(5)
print(f"\nToken '{str_tokens[pos]}':")
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
# 6. Check reconstruction quality
reconstructed = sae.decode(features)
mse = ((activations - reconstructed) ** 2).mean()
print(f"\nReconstruction MSE: {mse.item():.6f}")
```
---
## Tutorial 2: Training a Custom SAE
### Goal
Train a Sparse Autoencoder on GPT-2 activations.
### Step-by-Step
```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.6.hook_resid_pre",
hook_layer=6,
d_in=768,
# SAE architecture
architecture="standard",
d_sae=768 * 8, # 8x expansion
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5,
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=10_000_000, # Small run for demo
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Dead feature prevention
use_ghost_grads=True,
dead_feature_window=5000,
# Logging
log_to_wandb=True,
wandb_project="sae-training-demo",
# Hardware
device="cuda",
dtype="float32",
)
# 2. Train
runner = SAETrainingRunner(cfg)
sae = runner.run()
# 3. Save
sae.save_model("./my_trained_sae")
```
### Hyperparameter Tuning Guide
| If you see... | Try... |
|---------------|--------|
| High L0 (>200) | Increase `l1_coefficient` |
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
---
## Tutorial 3: Feature Attribution and Steering
### Goal
Identify which SAE features contribute to specific predictions and use them for steering.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 1. Feature attribution for a specific prediction
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Target token
target_token = model.to_single_token(" Paris")
# Compute feature contributions to target logit
# contribution = feature_activation * decoder_weight * unembedding
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U  # [d_model, d_vocab]
# Project each feature direction onto the target token's unembedding column only.
# The full W_dec @ W_U product is [d_sae, d_vocab] (several GB for this SAE) and
# is not needed when a single token's logit is of interest.
feature_to_logit = W_dec @ W_U[:, target_token]  # [d_sae]
# Contribution of each feature to "Paris" at final position
feature_acts = features[0, -1]  # [d_sae]
contributions = feature_acts * feature_to_logit
# Top contributing features
top_features = contributions.topk(10)
print("Top features contributing to 'Paris':")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
# 2. Feature steering
def steer_with_feature(feature_idx, strength=5.0):
"""Add a feature direction to the residual stream."""
feature_direction = sae.W_dec[feature_idx] # [d_model]
def hook(activation, hook_obj):
activation[:, -1, :] += strength * feature_direction
return activation
output = model.generate(
tokens,
max_new_tokens=10,
fwd_hooks=[("blocks.8.hook_resid_pre", hook)]
)
return model.to_string(output[0])
# Try steering with top feature
top_feature_idx = top_features.indices[0].item()
print(f"\nSteering with feature {top_feature_idx}:")
print(steer_with_feature(top_feature_idx, strength=10.0))
```
---
## Tutorial 4: Feature Ablation
### Goal
Test the causal importance of features by ablating them.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
# Baseline prediction
baseline_logits = model(tokens)
target_token = model.to_single_token(" Paris")
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
print(f"Baseline P(Paris): {baseline_prob:.4f}")
# Get features to ablate
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
top_features = features[0, -1].topk(10).indices
# Ablate top features one by one
for feat_idx in top_features:
    def ablation_hook(activation, hook, feat_idx=feat_idx):
        # Remove only this feature's decoder contribution from the residual
        # stream. Replacing the activation with sae.decode(...) would also
        # inject the SAE's reconstruction error, which the baseline run (no
        # SAE in the loop) does not contain, so the two would not be comparable.
        feats = sae.encode(activation)
        feature_act = feats[..., feat_idx].unsqueeze(-1)  # [batch, pos, 1]
        return activation - feature_act * sae.W_dec[feat_idx]
    ablated_logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
    )
    ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
    # Relative change in P(Paris) versus the un-ablated baseline, in percent
    change = (ablated_prob - baseline_prob) / baseline_prob * 100
    print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
```
---
## Tutorial 5: Comparing Features Across Prompts
### Goal
Find which features activate consistently for a concept.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Test prompts about the same concept
prompts = [
"The Eiffel Tower is located in",
"Paris is the capital of",
"France's largest city is",
"The Louvre museum is in",
]
# Collect feature activations
all_features = []
for prompt in prompts:
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Take max activation across positions
max_features = features[0].max(dim=0).values
all_features.append(max_features)
all_features = torch.stack(all_features) # [n_prompts, d_sae]
# Find features that activate consistently
mean_activation = all_features.mean(dim=0)
min_activation = all_features.min(dim=0).values
# Features active in ALL prompts
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
print(f"Features active in all prompts: {len(consistent_features)}")
# Top consistent features
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
print("\nTop consistent features (possibly 'France/Paris' related):")
for idx, val in zip(top_consistent.indices, top_consistent.values):
feat_idx = consistent_features[idx].item()
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
```
---
## External Resources
### Official Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
### ARENA Curriculum
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
### Key Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
@@ -0,0 +1,346 @@
---
name: transformer-lens-interpretability
description: Provides guidance for mechanistic interpretability research using TransformerLens to inspect and manipulate transformer internals via HookPoints and activation caching. Use when reverse-engineering model algorithms, studying attention patterns, or performing activation patching experiments.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Mechanistic Interpretability, TransformerLens, Activation Patching, Circuit Analysis]
dependencies: [transformer-lens>=2.0.0, torch>=2.0.0]
---
# TransformerLens: Mechanistic Interpretability for Transformers
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
**GitHub**: [TransformerLensOrg/TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (2,900+ stars)
## When to Use TransformerLens
**Use TransformerLens when you need to:**
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution
**Consider alternatives when:**
- You need to work with non-transformer architectures → Use **nnsight** or **pyvene**
- You want to train/analyze Sparse Autoencoders → Use **SAELens**
- You need remote execution on massive models → Use **nnsight** with NDIF
- You want higher-level causal intervention abstractions → Use **pyvene**
## Installation
```bash
pip install transformer-lens
```
For development version:
```bash
pip install git+https://github.com/TransformerLensOrg/TransformerLens
```
## Core Concepts
### HookedTransformer
The main class that wraps transformer models with HookPoints on every activation:
```python
from transformer_lens import HookedTransformer
# Load a model
model = HookedTransformer.from_pretrained("gpt2-small")
# For gated models (LLaMA, Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```
### Supported Models (50+)
| Family | Models |
|--------|--------|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |
### Activation Caching
Run the model and cache all intermediate activations:
```python
# Get all activations
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)
# Access specific activations
residual = cache["resid_post", 5] # Layer 5 residual stream
attn_pattern = cache["pattern", 3] # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7] # Layer 7 MLP output
# Filter which activations to cache (saves memory)
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda name: "resid_post" in name
)
```
### ActivationCache Keys
| Key Pattern | Shape | Description |
|-------------|-------|-------------|
| `resid_pre, layer` | [batch, pos, d_model] | Residual before attention |
| `resid_mid, layer` | [batch, pos, d_model] | Residual after attention |
| `resid_post, layer` | [batch, pos, d_model] | Residual after MLP |
| `attn_out, layer` | [batch, pos, d_model] | Attention output |
| `mlp_out, layer` | [batch, pos, d_model] | MLP output |
| `pattern, layer` | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| `q, layer` | [batch, pos, head, d_head] | Query vectors |
| `k, layer` | [batch, pos, head, d_head] | Key vectors |
| `v, layer` | [batch, pos, head, d_head] | Value vectors |
## Workflow 1: Activation Patching (Causal Tracing)
Identify which activations causally affect model output by patching clean activations into corrupted runs.
### Step-by-Step
```python
from transformer_lens import HookedTransformer, patching
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# 1. Define clean and corrupted prompts
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
# Position-wise patching copies clean activations into the corrupted run at the
# same index, so both prompts must tokenize to the same length. If this fails,
# pick prompts whose differing words split into the same number of tokens
# (inspect with model.to_str_tokens).
assert clean_tokens.shape == corrupted_tokens.shape, (
    "Clean and corrupted prompts must have the same number of tokens: "
    f"{clean_tokens.shape[1]} vs {corrupted_tokens.shape[1]}"
)
# 2. Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)
# 3. Define metric (e.g., logit difference)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")
def metric(logits):
return logits[0, -1, paris_token] - logits[0, -1, rome_token]
# 4. Patch each position and layer
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
for layer in range(model.cfg.n_layers):
for pos in range(clean_tokens.shape[1]):
def patch_hook(activation, hook):
activation[0, pos] = clean_cache[hook.name][0, pos]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
)
results[layer, pos] = metric(patched_logits)
# 5. Visualize results (layer x position heatmap)
```
### Checklist
- [ ] Define clean and corrupted inputs that differ minimally
- [ ] Choose metric that captures behavior difference
- [ ] Cache clean activations
- [ ] Systematically patch each (layer, position) combination
- [ ] Visualize results as heatmap
- [ ] Identify causal hotspots
## Workflow 2: Circuit Analysis (Indirect Object Identification)
Replicate the IOI circuit discovery from "Interpretability in the Wild".
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# IOI task: "When John and Mary went to the store, Mary gave a bottle to"
# Model should predict "John" (indirect object)
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)
# 1. Get baseline logits
logits, cache = model.run_with_cache(tokens)
john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")
# 2. Compute logit difference (IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")
# 3. Direct logit attribution by head
def get_head_contribution(layer, head):
# Project head output to logits
head_out = cache["z", layer][0, :, head, :] # [pos, d_head]
W_O = model.W_O[layer, head] # [d_head, d_model]
W_U = model.W_U # [d_model, vocab]
# Head contribution to logits at final position
contribution = head_out[-1] @ W_O @ W_U
return contribution[john_token] - contribution[mary_token]
# 4. Map all heads
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
for head in range(model.cfg.n_heads):
head_contributions[layer, head] = get_head_contribution(layer, head)
# 5. Identify top contributing heads (name movers, backup name movers)
```
### Checklist
- [ ] Set up task with clear IO/S tokens
- [ ] Compute baseline logit difference
- [ ] Decompose by attention head contributions
- [ ] Identify key circuit components (name movers, S-inhibition, induction)
- [ ] Validate with ablation experiments
## Workflow 3: Induction Head Detection
Find induction heads that implement [A][B]...[A] → [B] pattern.
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# Create repeated sequence: [A][B][A] should predict [B]
repeated_tokens = torch.tensor([[1000, 2000, 1000]]) # Arbitrary tokens
_, cache = model.run_with_cache(repeated_tokens)
# Induction heads attend from final [A] back to first [B]
# Check attention from position 2 to position 1
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
pattern = cache["pattern", layer][0] # [head, q_pos, k_pos]
# Attention from pos 2 to pos 1
induction_scores[layer] = pattern[:, 2, 1]
# Heads with high scores are induction heads
top_heads = torch.topk(induction_scores.flatten(), k=5)
```
## Common Issues & Solutions
### Issue: Hooks persist after debugging
```python
# WRONG: Hooks attached with add_hook() stay active until explicitly removed
model.add_hook("blocks.0.hook_resid_pre", debug_hook)
model(tokens)  # debug_hook runs
model(tokens)  # ...and still runs on every later forward pass!
# RIGHT: Always reset hooks before a fresh experiment, and prefer
# run_with_hooks(), which removes its temporary hooks when the call finishes
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])
```
### Issue: Tokenization gotchas
```python
# WRONG: Assuming consistent tokenization
model.to_tokens("Tim") # Single token
model.to_tokens("Neel") # Becomes "Ne" + "el" (two tokens!)
# RIGHT: Check tokenization explicitly
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens)) # ['Ne', 'el']
```
### Issue: LayerNorm ignored in analysis
```python
# WRONG: Ignoring LayerNorm
pre_activation = residual @ model.W_in[layer]
# RIGHT: Include LayerNorm
# Call the module rather than reading its weights: with the default
# fold_ln=True the learnable scale/bias are folded into W_in/b_in, so
# ln2 has no `.w` attribute to read.
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer] + model.b_in[layer]
```
### Issue: Memory explosion with large models
```python
# Use selective caching
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda n: "resid_post" in n or "pattern" in n,
device="cpu" # Cache on CPU
)
```
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `HookedTransformer` | Main model wrapper with hooks |
| `ActivationCache` | Dictionary-like cache of activations |
| `HookedTransformerConfig` | Model configuration |
| `FactoredMatrix` | Efficient factored matrix operations |
## Integration with SAELens
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
model = HookedTransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
# Run with SAE
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
```
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for activation patching, circuit analysis, logit lens |
## External Resources
### Tutorials
- [Main Demo Notebook](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html)
- [Activation Patching Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)
- [ARENA Mech Interp Course](https://arena-foundation.github.io/ARENA/) - 200+ hours of tutorials
### Papers
- [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)
- [In-context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)
- [Interpretability in the Wild (IOI)](https://arxiv.org/abs/2211.00593)
### Official Documentation
- [Official Docs](https://transformerlensorg.github.io/TransformerLens/)
- [Model Properties Table](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html)
- [Neel Nanda's Glossary](https://www.neelnanda.io/mechanistic-interpretability/glossary)
## Version Notes
- **v2.0**: Removed HookedSAE (moved to SAELens)
- **v3.0 (alpha)**: TransformerBridge for loading any nn.Module
@@ -0,0 +1,54 @@
# TransformerLens Reference Documentation
This directory contains comprehensive reference materials for TransformerLens.
## Contents
- [api.md](api.md) - Complete API reference for HookedTransformer, ActivationCache, and HookPoints
- [tutorials.md](tutorials.md) - Step-by-step tutorials for common interpretability workflows
- [papers.md](papers.md) - Key research papers and foundational concepts
## Quick Links
- **Official Documentation**: https://transformerlensorg.github.io/TransformerLens/
- **GitHub Repository**: https://github.com/TransformerLensOrg/TransformerLens
- **Model Properties Table**: https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html
## Installation
```bash
pip install transformer-lens
```
## Basic Usage
```python
from transformer_lens import HookedTransformer
# Load model
model = HookedTransformer.from_pretrained("gpt2-small")
# Run with activation caching
tokens = model.to_tokens("Hello world")
logits, cache = model.run_with_cache(tokens)
# Access activations
residual = cache["resid_post", 5] # Layer 5 residual stream
attention = cache["pattern", 3] # Layer 3 attention patterns
```
## Key Concepts
### HookPoints
Every activation in the transformer has a HookPoint wrapper, enabling:
- Reading activations via `run_with_cache()`
- Modifying activations via `run_with_hooks()`
### Activation Cache
The `ActivationCache` stores all intermediate activations with helper methods for:
- Residual stream decomposition
- Logit attribution
- Layer-wise analysis
### Supported Models (50+)
GPT-2, LLaMA, Mistral, Pythia, GPT-Neo, OPT, Gemma, Phi, and more.
@@ -0,0 +1,362 @@
# TransformerLens API Reference
## HookedTransformer
The core class for mechanistic interpretability, wrapping transformer models with hooks on every activation.
### Loading Models
```python
import torch
from transformer_lens import HookedTransformer
# Basic loading
model = HookedTransformer.from_pretrained("gpt2-small")
# With specific device/dtype
model = HookedTransformer.from_pretrained(
"gpt2-medium",
device="cuda",
dtype=torch.float16
)
# Gated models (LLaMA, Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```
### from_pretrained() Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `model_name` | str | required | Model name from OFFICIAL_MODEL_NAMES |
| `fold_ln` | bool | True | Fold LayerNorm weights into subsequent layers |
| `center_writing_weights` | bool | True | Center residual stream writer means |
| `center_unembed` | bool | True | Center unembedding weights |
| `dtype` | torch.dtype | None | Model precision |
| `device` | str | None | Target device |
| `n_devices` | int | 1 | Number of devices for model parallelism |
### Weight Matrices
| Property | Shape | Description |
|----------|-------|-------------|
| `W_E` | [d_vocab, d_model] | Token embedding matrix |
| `W_U` | [d_model, d_vocab] | Unembedding matrix |
| `W_pos` | [n_ctx, d_model] | Positional embedding |
| `W_Q` | [n_layers, n_heads, d_model, d_head] | Query weights |
| `W_K` | [n_layers, n_heads, d_model, d_head] | Key weights |
| `W_V` | [n_layers, n_heads, d_model, d_head] | Value weights |
| `W_O` | [n_layers, n_heads, d_head, d_model] | Output weights |
| `W_in` | [n_layers, d_model, d_mlp] | MLP input weights |
| `W_out` | [n_layers, d_mlp, d_model] | MLP output weights |
### Core Methods
#### forward()
```python
logits = model(tokens)
logits = model(tokens, return_type="logits")
loss = model(tokens, return_type="loss")
logits, loss = model(tokens, return_type="both")
```
Parameters:
- `input`: Token tensor or string
- `return_type`: "logits", "loss", "both", or None
- `prepend_bos`: Whether to prepend BOS token
- `start_at_layer`: Start execution from specific layer
- `stop_at_layer`: Stop execution at specific layer
#### run_with_cache()
```python
logits, cache = model.run_with_cache(tokens)
# Selective caching (saves memory)
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda name: "resid_post" in name
)
# Cache on CPU
logits, cache = model.run_with_cache(tokens, device="cpu")
```
#### run_with_hooks()
```python
def my_hook(activation, hook):
# Modify activation
activation[:, :, 0] = 0
return activation
logits = model.run_with_hooks(
tokens,
fwd_hooks=[("blocks.5.hook_resid_post", my_hook)]
)
```
#### generate()
```python
output = model.generate(
tokens,
max_new_tokens=50,
temperature=0.7,
top_k=40,
top_p=0.9,
freq_penalty=1.0,
use_past_kv_cache=True
)
```
### Tokenization Methods
```python
# String to tokens
tokens = model.to_tokens("Hello world") # [1, seq_len]
tokens = model.to_tokens("Hello", prepend_bos=False)
# Tokens to string
text = model.to_string(tokens)
# Get string tokens (for debugging)
str_tokens = model.to_str_tokens("Hello world")
# ['<|endoftext|>', 'Hello', ' world']
# Single token validation
token_id = model.to_single_token(" Paris") # Returns int or raises error
```
### Hook Management
```python
# Clear all hooks
model.reset_hooks()
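# Add a hook that stays attached until reset_hooks() is called
# (use add_perma_hook() instead for a hook that survives reset_hooks())
model.add_hook("blocks.0.hook_resid_post", my_hook)
# Remove the non-permanent hooks from one specific hook point
model.hook_dict["blocks.0.hook_resid_post"].remove_hooks()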
```
---
## ActivationCache
Stores and provides access to all activations from a forward pass.
### Accessing Activations
```python
logits, cache = model.run_with_cache(tokens)
# By name and layer
residual = cache["resid_post", 5]
attention = cache["pattern", 3]
mlp_out = cache["mlp_out", 7]
# Full name string
residual = cache["blocks.5.hook_resid_post"]
```
### Cache Keys
| Key Pattern | Shape | Description |
|-------------|-------|-------------|
| `hook_embed` | [batch, pos, d_model] | Token embeddings |
| `hook_pos_embed` | [batch, pos, d_model] | Positional embeddings |
| `resid_pre, layer` | [batch, pos, d_model] | Residual before attention |
| `resid_mid, layer` | [batch, pos, d_model] | Residual after attention |
| `resid_post, layer` | [batch, pos, d_model] | Residual after MLP |
| `attn_out, layer` | [batch, pos, d_model] | Attention output |
| `mlp_out, layer` | [batch, pos, d_model] | MLP output |
| `pattern, layer` | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| `attn_scores, layer` | [batch, head, q_pos, k_pos] | Attention scores (pre-softmax) |
| `q, layer` | [batch, pos, head, d_head] | Query vectors |
| `k, layer` | [batch, pos, head, d_head] | Key vectors |
| `v, layer` | [batch, pos, head, d_head] | Value vectors |
| `z, layer` | [batch, pos, head, d_head] | Attention output per head |
### Analysis Methods
#### decompose_resid()
Decomposes residual stream into component contributions:
```python
components, labels = cache.decompose_resid(
layer=5,
return_labels=True,
mode="attn" # or "mlp" or "all"
)
```
#### accumulated_resid()
Get accumulated residual at each layer (for Logit Lens):
```python
accumulated = cache.accumulated_resid(
layer=None, # All layers
incl_mid=False,
apply_ln=True # Apply final LayerNorm
)
```
#### logit_attrs()
Calculate logit attribution for components:
```python
attrs = cache.logit_attrs(
residual_stack,
tokens=target_tokens,
incorrect_tokens=incorrect_tokens
)
```
#### stack_head_results()
Stack attention head outputs:
```python
head_results = cache.stack_head_results(
layer=-1, # All layers
pos_slice=None # All positions
)
# Shape: [n_layers * n_heads, batch, pos, d_model] (heads are flattened into one component axis)
```
### Utility Methods
```python
# Move cache to device
cache = cache.to("cpu")
# Remove batch dimension (for batch_size=1)
cache = cache.remove_batch_dim()
# Get all keys
keys = cache.keys()
# Iterate
for name, activation in cache.items():
print(name, activation.shape)
```
---
## HookPoint
The fundamental hook mechanism wrapping every activation.
### Hook Function Signature
```python
def hook_fn(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:
"""
Args:
activation: Current activation value
hook: The HookPoint object (has .name attribute)
Returns:
Modified activation (or None to keep original)
"""
# Modify activation
return activation
```
### Common Hook Patterns
```python
# Zero ablation
def zero_hook(act, hook):
act[:, :, :] = 0
return act
# Mean ablation
def mean_hook(act, hook):
act[:, :, :] = act.mean(dim=0, keepdim=True)
return act
# Patch from cache
def patch_hook(act, hook):
act[:, 5, :] = clean_cache[hook.name][:, 5, :]
return act
# Add steering vector
def steer_hook(act, hook):
act += 0.5 * steering_vector
return act
```
---
## Utility Functions
### patching module
```python
from transformer_lens import patching
# Generic activation patching
results = patching.generic_activation_patch(
model=model,
corrupted_tokens=corrupted,
clean_cache=clean_cache,
patching_metric=metric_fn,
patch_setter=patch_fn,
activation_name="resid_post",
index_axis_names=("layer", "pos")
)
```
### FactoredMatrix
Efficient operations on factored weight matrices:
```python
from transformer_lens import FactoredMatrix
# QK circuit
QK = FactoredMatrix(model.W_Q[layer], model.W_K[layer].transpose(-2, -1))  # swap only the last two dims; .T would reverse all three
# OV circuit
OV = FactoredMatrix(model.W_V[layer], model.W_O[layer])
# Get full matrix
full = QK.AB
# SVD decomposition
U, S, V = QK.svd()
```
---
## Configuration
### HookedTransformerConfig
Key configuration attributes:
| Attribute | Description |
|-----------|-------------|
| `n_layers` | Number of transformer layers |
| `n_heads` | Number of attention heads |
| `d_model` | Model dimension |
| `d_head` | Head dimension |
| `d_mlp` | MLP hidden dimension |
| `d_vocab` | Vocabulary size |
| `n_ctx` | Maximum context length |
| `act_fn` | Activation function name |
| `normalization_type` | "LN" or "LNPre" |
Access via:
```python
model.cfg.n_layers
model.cfg.d_model
```
@@ -0,0 +1,339 @@
# TransformerLens Tutorials
## Tutorial 1: Basic Activation Analysis
### Goal
Understand how to load models, cache activations, and inspect model internals.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
# 1. Load model
model = HookedTransformer.from_pretrained("gpt2-small")
print(f"Model has {model.cfg.n_layers} layers, {model.cfg.n_heads} heads")
# 2. Tokenize input
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
print(f"Tokens shape: {tokens.shape}")
print(f"String tokens: {model.to_str_tokens(prompt)}")
# 3. Run with cache
logits, cache = model.run_with_cache(tokens)
print(f"Logits shape: {logits.shape}")
print(f"Cache keys: {len(cache.keys())}")
# 4. Inspect activations
for layer in range(model.cfg.n_layers):
resid = cache["resid_post", layer]
print(f"Layer {layer} residual norm: {resid.norm().item():.2f}")
# 5. Look at attention patterns
attn = cache["pattern", 0] # Layer 0
print(f"Attention shape: {attn.shape}") # [batch, heads, q_pos, k_pos]
# 6. Get top predictions
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
print(f"{model.to_string(token_id.unsqueeze(0))}: {prob.item():.3f}")
```
---
## Tutorial 2: Activation Patching
### Goal
Identify which activations causally affect model output.
### Concept
1. Run model on "clean" input, cache activations
2. Run model on "corrupted" input
3. Patch clean activations into corrupted run
4. Measure effect on output
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# Define clean and corrupted prompts
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
# Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)
# Define metric
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")
def logit_diff(logits):
"""Positive = model prefers Paris over Rome"""
return (logits[0, -1, paris_token] - logits[0, -1, rome_token]).item()
# Baseline measurements
clean_logits = model(clean_tokens)
corrupted_logits = model(corrupted_tokens)
print(f"Clean logit diff: {logit_diff(clean_logits):.3f}")
print(f"Corrupted logit diff: {logit_diff(corrupted_logits):.3f}")
# Patch each layer
results = []
for layer in range(model.cfg.n_layers):
def patch_hook(activation, hook, layer=layer):
activation[:] = clean_cache["resid_post", layer]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
)
results.append(logit_diff(patched_logits))
print(f"Layer {layer}: {results[-1]:.3f}")
# Find most important layer
best_layer = max(range(len(results)), key=lambda i: results[i])
print(f"\nMost important layer: {best_layer}")
```
### Position-Specific Patching
```python
import torch
seq_len = clean_tokens.shape[1]
results = torch.zeros(model.cfg.n_layers, seq_len)
for layer in range(model.cfg.n_layers):
for pos in range(seq_len):
def patch_hook(activation, hook, layer=layer, pos=pos):
activation[:, pos, :] = clean_cache["resid_post", layer][:, pos, :]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
)
results[layer, pos] = logit_diff(patched_logits)
# Visualize as heatmap
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plt.imshow(results.numpy(), aspect='auto', cmap='RdBu')
plt.xlabel('Position')
plt.ylabel('Layer')
plt.colorbar(label='Logit Difference')
plt.title('Activation Patching Results')
```
---
## Tutorial 3: Direct Logit Attribution
### Goal
Identify which components (heads, neurons) contribute to specific predictions.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
logits, cache = model.run_with_cache(tokens)
# Target token
target_token = model.to_single_token(" Paris")
# Get unembedding direction for target
target_direction = model.W_U[:, target_token] # [d_model]
# Attribution per attention head
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
# Get per-head output at final position
z = cache["z", layer][0, -1] # [n_heads, d_head]
for head in range(model.cfg.n_heads):
# Project through W_O to get contribution to residual
head_out = z[head] @ model.W_O[layer, head] # [d_model]
# Dot with target direction
contribution = (head_out @ target_direction).item()
head_contributions[layer, head] = contribution
# Find top contributing heads
flat_idx = head_contributions.flatten().topk(10)
print("Top 10 heads for predicting 'Paris':")
for idx, val in zip(flat_idx.indices, flat_idx.values):
layer = idx.item() // model.cfg.n_heads
head = idx.item() % model.cfg.n_heads
print(f" L{layer}H{head}: {val.item():.3f}")
```
---
## Tutorial 4: Induction Head Detection
### Goal
Find attention heads that implement the [A][B]...[A] → [B] pattern.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# Create repeated sequence pattern
# Pattern: [A][B][C][A] - model should attend from last A to B
seq = torch.randint(1000, 5000, (1, 20))
# Repeat first half
seq[0, 10:] = seq[0, :10]
_, cache = model.run_with_cache(seq)
# For induction heads: position i should attend to position (i - seq_len/2 + 1)
# At position 10 (second A), should attend to position 1 (first B)
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
pattern = cache["pattern", layer][0] # [heads, q_pos, k_pos]
# Check attention from repeated positions to position after first occurrence
for offset in range(1, 10):
q_pos = 10 + offset # Position in second half
k_pos = offset + 1 # Should attend to the token AFTER the first occurrence of the current token
# Average attention to the "correct" position
induction_scores[layer] += pattern[:, q_pos, k_pos]
induction_scores[layer] /= 9 # Average over offsets
# Find top induction heads
print("Top induction heads:")
for layer in range(model.cfg.n_layers):
for head in range(model.cfg.n_heads):
score = induction_scores[layer, head].item()
if score > 0.3:
print(f" L{layer}H{head}: {score:.3f}")
```
---
## Tutorial 5: Logit Lens
### Goal
See what the model "believes" at each layer before final unembedding.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
prompt = "The quick brown fox jumps over the lazy"
tokens = model.to_tokens(prompt)
logits, cache = model.run_with_cache(tokens)
# Get accumulated residual at each layer
# Apply LayerNorm to match what unembedding sees
accumulated = cache.accumulated_resid(layer=None, incl_mid=False, apply_ln=True)
# Shape: [n_layers + 1, batch, pos, d_model]
# Project to vocabulary
layer_logits = accumulated @ model.W_U # [n_layers + 1, batch, pos, d_vocab]
# Look at predictions for final position
print("Layer-by-layer predictions for final token:")
for layer in range(model.cfg.n_layers + 1):
probs = torch.softmax(layer_logits[layer, 0, -1], dim=-1)
top_token = probs.argmax()
top_prob = probs[top_token].item()
print(f"Layer {layer}: {model.to_string(top_token.unsqueeze(0))!r} ({top_prob:.3f})")
```
---
## Tutorial 6: Steering with Activation Addition
### Goal
Add a steering vector to change model behavior.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# Get activations for contrasting prompts
positive_prompt = "I love this! It's absolutely wonderful and"
negative_prompt = "I hate this! It's absolutely terrible and"
_, pos_cache = model.run_with_cache(model.to_tokens(positive_prompt))
_, neg_cache = model.run_with_cache(model.to_tokens(negative_prompt))
# Compute steering vector (positive - negative direction)
layer = 6
steering_vector = (
pos_cache["resid_post", layer].mean(dim=1) -
neg_cache["resid_post", layer].mean(dim=1)
)
# Generate with steering
test_prompt = "The movie was"
test_tokens = model.to_tokens(test_prompt)
def steer_hook(activation, hook):
activation += 2.0 * steering_vector
return activation
# Without steering
normal_output = model.generate(test_tokens, max_new_tokens=20)
print(f"Normal: {model.to_string(normal_output[0])}")
# With positive steering
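# generate() takes no fwd_hooks argument; attach the hook with the hooks() context manager
with model.hooks(fwd_hooks=[(f"blocks.{layer}.hook_resid_post", steer_hook)]):
    steered_output = model.generate(test_tokens, max_new_tokens=20)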
print(f"Steered: {model.to_string(steered_output[0])}")
```
---
## External Resources
### Official Tutorials
- [Main Demo](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html)
- [Exploratory Analysis](https://transformerlensorg.github.io/TransformerLens/generated/demos/Exploratory_Analysis_Demo.html)
- [Activation Patching Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)
### ARENA Course
Comprehensive 200+ hour curriculum: https://arena-foundation.github.io/ARENA/
### Neel Nanda's Resources
- [Getting Started in Mech Interp](https://www.neelnanda.io/mechanistic-interpretability/getting-started)
- [Mech Interp Glossary](https://www.neelnanda.io/mechanistic-interpretability/glossary)
- [YouTube Channel](https://www.youtube.com/@neelnanda)
@@ -0,0 +1,5 @@
# Skills Coming Soon
This directory will contain high-quality AI research skills for data processing.
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
@@ -0,0 +1,383 @@
---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
dependencies: [nemo-curator, cudf, dask, rapids]
---
# NeMo Curator - GPU-Accelerated Data Curation
NVIDIA's toolkit for preparing high-quality training data for LLMs.
## When to use NeMo Curator
**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster
**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes
**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)
## Quick start
### Installation
```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"
# All modalities
uv pip install "nemo-curator[all_cuda12]"
# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```
### Basic text curation pipeline
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd
# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)
# Quality filtering
def quality_score(doc):
return len(doc["text"].split()) > 5 # Filter short docs
filtered = ScoreFilter(quality_score)(dataset)
# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)
# Save
deduped.to_parquet("curated_data/")
```
## Data curation pipeline
### Stage 1: Quality filtering
```python
from nemo_curator.filters import (
WordCountFilter,
RepeatedLinesFilter,
UrlRatioFilter,
NonAlphaNumericFilter
)
# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter
# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
### Stage 2: Deduplication
**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates
# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```
**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates
# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash parameters
num_buckets=20,
hash_method="md5"
)
deduped = fuzzy_dedup(dataset)
```
**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates
# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
threshold=0.8 # Cosine similarity threshold
)
deduped = semantic_dedup(dataset)
```
### Stage 3: PII redaction
```python
from nemo_curator.modules import Modify
from nemo_curator.modifiers import PIIRedactor
# Redact personally identifiable information
pii_redactor = PIIRedactor(
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
anonymize_action="replace" # or "redact"
)
redacted = Modify(pii_redactor)(dataset)
```
### Stage 4: Classifier filtering
```python
from nemo_curator.classifiers import QualityClassifier
# Quality classification
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality documents
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
## GPU acceleration
### GPU vs CPU performance
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |
### Multi-GPU scaling
```python
from nemo_curator import get_client
import dask_cuda
# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)
# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```
## Multi-modal curation
### Image curation
```python
from nemo_curator.image import (
AestheticFilter,
NSFWFilter,
CLIPEmbedder
)
# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)
# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)
# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```
### Video curation
```python
from nemo_curator.video import (
SceneDetector,
ClipExtractor,
InternVideo2Embedder
)
# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)
# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)
# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```
### Audio curation
```python
from nemo_curator.audio import (
ASRInference,
WERFilter,
DurationFilter
)
# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)
# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)
# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```
## Common patterns
### Web scrape curation (Common Crawl)
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset
# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
# Pipeline
pipeline = [
# 1. Quality filtering
WordCountFilter(min_words=100, max_words=50000),
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
UrlRatioFilter(max_url_ratio=0.3),
# 2. Language filtering
LanguageIdentificationFilter(target_languages=["en"]),
# 3. Deduplication
ExactDuplicates(id_field="id", text_field="text"),
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
# 4. PII redaction
PIIRedactor(),
# 5. NSFW filtering
NSFWClassifier(threshold=0.8)
]
# Execute
for stage in pipeline:
dataset = stage(dataset)
# Save
dataset.to_parquet("curated_common_crawl/")
```
### Distributed processing
```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster
# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)
# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)
# Cleanup
client.close()
cluster.close()
```
## Performance benchmarks
### Fuzzy deduplication (8TB RedPajama v2)
- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×
### Exact deduplication (1TB)
- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×
### Quality filtering (100GB)
- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×
## Cost comparison
**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320
**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55
**Savings**: 89% reduction ($3,828 saved)
## Supported data formats
- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal
## Use cases
**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile
## References
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
## Resources
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0
@@ -0,0 +1,87 @@
# Deduplication Guide
Complete guide to exact, fuzzy, and semantic deduplication.
## Exact deduplication
Remove documents with identical content.
```python
from nemo_curator.modules import ExactDuplicates
# Exact deduplication
exact_dedup = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5" # or "sha256"
)
deduped = exact_dedup(dataset)
```
**Performance**: ~16× faster on GPU vs CPU
## Fuzzy deduplication
Remove near-duplicate documents using MinHash + LSH.
```python
from nemo_curator.modules import FuzzyDuplicates
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash permutations (more = more accurate, but slower)
num_buckets=20, # LSH buckets/bands (more = higher recall, but more candidate pairs to check)
hash_method="md5",
jaccard_threshold=0.8 # Similarity threshold
)
deduped = fuzzy_dedup(dataset)
```
**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
## Semantic deduplication
Remove semantically similar documents using embeddings.
```python
from nemo_curator.modules import SemanticDuplicates
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=256,
threshold=0.85, # Cosine similarity threshold
device="cuda"
)
deduped = semantic_dedup(dataset)
```
**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported
## Comparison
| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |
## Best practices
1. **Start with exact dedup** - Remove obvious duplicates
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup
@@ -0,0 +1,102 @@
# Quality Filtering Guide
Complete guide to NeMo Curator's 30+ quality filters.
## Text-based filters
### Word count
```python
from nemo_curator.filters import WordCountFilter
# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```
### Repeated content
```python
from nemo_curator.filters import RepeatedLinesFilter
# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```
### Symbol ratio
```python
from nemo_curator.filters import SymbolToWordRatioFilter
# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```
### URL ratio
```python
from nemo_curator.filters import UrlRatioFilter
# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
## Language filtering
```python
from nemo_curator.filters import LanguageIdentificationFilter
# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```
## Classifier-based filtering
### Quality classifier
```python
from nemo_curator.classifiers import QualityClassifier
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Drop low-quality documents: keep only those scoring above 0.5 (higher score = higher quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
### NSFW classifier
```python
from nemo_curator.classifiers import NSFWClassifier
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```
## Heuristic filters
Commonly used filters (30+ available in total):
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...
## Best practices
1. **Apply cheap filters first** - Word count before GPU classifiers
2. **Tune thresholds on sample** - Test on 10k docs before full run
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)
@@ -0,0 +1,326 @@
---
name: ray-data
description: Scalable data processing for ML workloads. Streaming execution across CPU/GPU, supports Parquet/CSV/JSON/images. Integrates with Ray Train, PyTorch, TensorFlow. Scales from single machine to 100s of nodes. Use for batch inference, data preprocessing, multi-modal data loading, or distributed ETL pipelines.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Data Processing, Ray Data, Distributed Computing, ML Pipelines, Batch Inference, ETL, Scalable, Ray, PyTorch, TensorFlow]
dependencies: ["ray[data]", pyarrow, pandas]
---
# Ray Data - Scalable ML Data Processing
Distributed data processing library for ML and AI workloads.
## When to use Ray Data
**Use Ray Data when:**
- Processing large datasets (>100GB) for ML training
- Need distributed data preprocessing across cluster
- Building batch inference pipelines
- Loading multi-modal data (images, audio, video)
- Scaling data processing from laptop to cluster
**Key features**:
- **Streaming execution**: Process data larger than memory
- **GPU support**: Accelerate transforms with GPUs
- **Framework integration**: PyTorch, TensorFlow, HuggingFace
- **Multi-modal**: Images, Parquet, CSV, JSON, audio, video
**Use alternatives instead**:
- **Pandas**: Small data (<1GB) on single machine
- **Dask**: Tabular data, SQL-like operations
- **Spark**: Enterprise ETL, SQL queries
## Quick start
### Installation
```bash
pip install -U 'ray[data]'
```
### Load and transform data
```python
import ray
# Read Parquet files
ds = ray.data.read_parquet("s3://bucket/data/*.parquet")
# Transform data (lazy execution)
ds = ds.map_batches(lambda batch: {"processed": batch["text"].str.lower()})
# Consume data
for batch in ds.iter_batches(batch_size=100):
print(batch)
```
### Integration with Ray Train
```python
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
# Create dataset
train_ds = ray.data.read_parquet("s3://bucket/train/*.parquet")
def train_func(config):
# Access dataset in training
train_ds = ray.train.get_dataset_shard("train")
for epoch in range(10):
for batch in train_ds.iter_batches(batch_size=32):
# Train on batch
pass
# Train with Ray
trainer = TorchTrainer(
train_func,
datasets={"train": train_ds},
scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)
trainer.fit()
```
## Reading data
### From cloud storage
```python
import ray
# Parquet (recommended for ML)
ds = ray.data.read_parquet("s3://bucket/data/*.parquet")
# CSV
ds = ray.data.read_csv("s3://bucket/data/*.csv")
# JSON
ds = ray.data.read_json("gs://bucket/data/*.json")
# Images
ds = ray.data.read_images("s3://bucket/images/")
```
### From Python objects
```python
# From list
ds = ray.data.from_items([{"id": i, "value": i * 2} for i in range(1000)])
# From range
ds = ray.data.range(1000000) # Synthetic data
# From pandas
import pandas as pd
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
ds = ray.data.from_pandas(df)
```
## Transformations
### Map batches (vectorized)
```python
# Batch transformation (fast)
def process_batch(batch):
batch["doubled"] = batch["value"] * 2
return batch
ds = ds.map_batches(process_batch, batch_size=1000)
```
### Row transformations
```python
# Row-by-row (slower)
def process_row(row):
row["squared"] = row["value"] ** 2
return row
ds = ds.map(process_row)
```
### Filter
```python
# Filter rows
ds = ds.filter(lambda row: row["value"] > 100)
```
### Group by and aggregate
```python
# Group by column
ds = ds.groupby("category").count()
# Custom aggregation
ds = ds.groupby("category").map_groups(lambda group: {"sum": group["value"].sum()})
```
## GPU-accelerated transforms
```python
# Use GPU for preprocessing
def preprocess_images_gpu(batch):
import torch
images = torch.tensor(batch["image"]).cuda()
# GPU preprocessing
processed = images * 255
return {"processed": processed.cpu().numpy()}
ds = ds.map_batches(
preprocess_images_gpu,
batch_size=64,
num_gpus=1 # Request GPU
)
```
## Writing data
```python
# Write to Parquet
ds.write_parquet("s3://bucket/output/")
# Write to CSV
ds.write_csv("output/")
# Write to JSON
ds.write_json("output/")
```
## Performance optimization
### Repartition
```python
# Control parallelism
ds = ds.repartition(100) # 100 blocks for 100-core cluster
```
### Batch size tuning
```python
# Larger batches = faster vectorized ops
ds.map_batches(process_fn, batch_size=10000) # vs batch_size=100
```
### Streaming execution
```python
# Process data larger than memory
ds = ray.data.read_parquet("s3://huge-dataset/")
for batch in ds.iter_batches(batch_size=1000):
process(batch) # Streamed, not loaded to memory
```
## Common patterns
### Batch inference
```python
import ray
# Load model
def load_model():
# Load once per worker
return MyModel()
# Inference function
class BatchInference:
def __init__(self):
self.model = load_model()
def __call__(self, batch):
predictions = self.model(batch["input"])
return {"prediction": predictions}
# Run distributed inference
ds = ray.data.read_parquet("s3://data/")
predictions = ds.map_batches(BatchInference, batch_size=32, num_gpus=1)
predictions.write_parquet("s3://output/")
```
### Data preprocessing pipeline
```python
# Multi-step pipeline
ds = (
ray.data.read_parquet("s3://raw/")
.map_batches(clean_data)
.map_batches(tokenize)
.map_batches(augment)
.write_parquet("s3://processed/")
)
```
## Integration with ML frameworks
### PyTorch
```python
# Convert to PyTorch
torch_ds = ds.to_torch(label_column="label", batch_size=32)
for batch in torch_ds:
# batch is dict with tensors
inputs, labels = batch["features"], batch["label"]
```
### TensorFlow
```python
# Convert to TensorFlow
tf_ds = ds.to_tf(feature_columns=["image"], label_column="label", batch_size=32)
for features, labels in tf_ds:
# Train model
pass
```
## Supported data formats
| Format | Read | Write | Use Case |
|--------|------|-------|----------|
| Parquet | ✅ | ✅ | ML data (recommended) |
| CSV | ✅ | ✅ | Tabular data |
| JSON | ✅ | ✅ | Semi-structured |
| Images | ✅ | ❌ | Computer vision |
| NumPy | ✅ | ✅ | Arrays |
| Pandas | ✅ | ❌ | DataFrames |
## Performance benchmarks
**Scaling** (processing 100GB data):
- 1 node (16 cores): ~30 minutes
- 4 nodes (64 cores): ~8 minutes
- 16 nodes (256 cores): ~2 minutes
**GPU acceleration** (image preprocessing):
- CPU only: 1,000 images/sec
- 1 GPU: 5,000 images/sec
- 4 GPUs: 18,000 images/sec
## Use cases
**Production deployments**:
- **Pinterest**: Last-mile data processing for model training
- **ByteDance**: Scaling offline inference with multi-modal LLMs
- **Spotify**: ML platform for batch inference
## References
- **[Transformations Guide](references/transformations.md)** - Map, filter, groupby operations
- **[Integration Guide](references/integration.md)** - Ray Train, PyTorch, TensorFlow
## Resources
- **Docs**: https://docs.ray.io/en/latest/data/data.html
- **GitHub**: https://github.com/ray-project/ray ⭐ 36,000+
- **Version**: Ray 2.40.0+
- **Examples**: https://docs.ray.io/en/latest/data/examples/overview.html
@@ -0,0 +1,82 @@
# Ray Data Integration Guide
Integration with Ray Train and ML frameworks.
## Ray Train integration
### Basic training with datasets
```python
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
# Create datasets
train_ds = ray.data.read_parquet("s3://data/train/")
val_ds = ray.data.read_parquet("s3://data/val/")
def train_func(config):
# Get dataset shards
train_ds = ray.train.get_dataset_shard("train")
val_ds = ray.train.get_dataset_shard("val")
for epoch in range(config["epochs"]):
# Iterate over batches
for batch in train_ds.iter_batches(batch_size=32):
# Train on batch
pass
# Launch training
trainer = TorchTrainer(
train_func,
train_loop_config={"epochs": 10},
datasets={"train": train_ds, "val": val_ds},
scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)
result = trainer.fit()
```
## PyTorch integration
### Convert to PyTorch Dataset
```python
# Option 1: to_torch (recommended)
torch_ds = ds.to_torch(
label_column="label",
batch_size=32,
drop_last=True
)
for batch in torch_ds:
inputs = batch["features"]
labels = batch["label"]
# Train model
# Option 2: iter_torch_batches
for batch in ds.iter_torch_batches(batch_size=32):
# batch is dict of tensors
pass
```
## TensorFlow integration
```python
tf_ds = ds.to_tf(
feature_columns=["image", "text"],
label_column="label",
batch_size=32
)
for features, labels in tf_ds:
# Train TensorFlow model
pass
```
## Best practices
1. **Shard datasets in Ray Train** - Automatic with `get_dataset_shard()`
2. **Use streaming** - Don't load the entire dataset into memory
3. **Preprocess in Ray Data** - Distribute preprocessing across cluster
4. **Cache preprocessed data** - Write to Parquet, read in training
@@ -0,0 +1,83 @@
# Ray Data Transformations
Complete guide to data transformations in Ray Data.
## Core operations
### Map batches (vectorized)
```python
# Recommended for performance
def process_batch(batch):
# batch is a dict of numpy arrays by default, or a pandas DataFrame when batch_format="pandas"
batch["doubled"] = batch["value"] * 2
return batch
ds = ds.map_batches(process_batch, batch_size=1000)
```
**Performance**: 10-100× faster than row-by-row
### Map (row-by-row)
```python
# Use only when vectorization not possible
def process_row(row):
row["squared"] = row["value"] ** 2
return row
ds = ds.map(process_row)
```
### Filter
```python
# Remove rows
ds = ds.filter(lambda row: row["score"] > 0.5)
```
### Flat map
```python
# One row → multiple rows
def expand_row(row):
return [{"value": row["value"] + i} for i in range(3)]
ds = ds.flat_map(expand_row)
```
## GPU-accelerated transforms
```python
def gpu_transform(batch):
import torch
data = torch.tensor(batch["data"]).cuda()
# GPU processing
result = data * 2
return {"processed": result.cpu().numpy()}
ds = ds.map_batches(gpu_transform, num_gpus=1, batch_size=64)
```
## Groupby operations
```python
# Group by column
grouped = ds.groupby("category")
# Aggregate
result = grouped.count()
# Custom aggregation
result = grouped.map_groups(lambda group: {
"sum": group["value"].sum(),
"mean": group["value"].mean()
})
```
## Best practices
1. **Use map_batches over map** - 10-100× faster
2. **Tune batch_size** - Larger = faster (balance with memory)
3. **Use GPUs for heavy compute** - Image/audio preprocessing
4. **Stream large datasets** - Use iter_batches for larger-than-memory data
@@ -0,0 +1,97 @@
# GRPO/RL Training Skill
**Expert-level guidance for Group Relative Policy Optimization with TRL**
## 📁 Skill Structure
```
grpo-rl-training/
├── SKILL.md # Main skill documentation (READ THIS FIRST)
├── README.md # This file
├── templates/
│ └── basic_grpo_training.py # Production-ready training template
└── examples/
└── reward_functions_library.py # 20+ reward function examples
```
## 🚀 Quick Start
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
4. **Modify for your use case** - Adapt dataset, rewards, and config
## 💡 What's Inside
### SKILL.md (Main Documentation)
- Core GRPO concepts and algorithm fundamentals
- Complete implementation workflow (dataset → rewards → training → deployment)
- 10+ reward function examples with code
- Hyperparameter tuning guide
- Training insights (loss behavior, metrics, debugging)
- Troubleshooting guide
- Production best practices
### Templates
- **basic_grpo_training.py**: Minimal, production-ready training script
- Uses Qwen 2.5 1.5B Instruct
- 3 reward functions (format + correctness)
- LoRA for efficient training
- Fully documented and ready to run
### Examples
- **reward_functions_library.py**: 20+ battle-tested reward functions
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
- Format rewards (XML, JSON, strict/soft)
- Length rewards (ideal length, min/max)
- Style rewards (reasoning quality, citations, repetition penalty)
- Combined rewards (multi-objective optimization)
- Preset collections for common tasks
## 📖 Usage for Agents
When this skill is loaded in your agent's context:
1. **Always read SKILL.md first** before implementing
2. **Start simple** - Use length-based reward to validate setup
3. **Build incrementally** - Add one reward function at a time
4. **Reference examples** - Copy patterns from reward_functions_library.py
5. **Monitor training** - Watch reward metrics (not loss!)
## 🎯 Common Use Cases
| Task Type | Recommended Rewards | Template |
|-----------|---------------------|----------|
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
## ⚠️ Critical Reminders
- **Loss goes UP during training** - This is normal (it's KL divergence)
- **Use 3-5 reward functions** - Single rewards often fail
- **Test rewards before training** - Debug each function independently
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
- **Start with num_generations=4-8** - Scale up if GPU allows
## 🔗 External Resources
- [TRL Documentation](https://huggingface.co/docs/trl)
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
## 📝 Version
**v1.0.0** - Initial release (January 2025)
## 👨‍💻 Maintained By
Orchestra Research
For questions or improvements, see https://orchestra.com
---
**License:** MIT
**Last Updated:** January 2025
@@ -0,0 +1,572 @@
---
name: grpo-rl-training
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
---
# GRPO/RL Training with TRL
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
## When to Use This Skill
Use GRPO training when you need to:
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
- **Align models to domain-specific behaviors** without labeled preference data
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
**Do NOT use GRPO for:**
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- When you already have high-quality preference pairs (use DPO/PPO instead)
---
## Core Concepts
### 1. GRPO Algorithm Fundamentals
**Key Mechanism:**
- Generates **multiple completions** for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group
**Critical Difference from PPO:**
- No separate value (critic) model needed; with rule-based reward functions, no learned reward model is required either
- More sample-efficient (learns from within-group comparisons)
- Simpler to implement and debug
**Mathematical Intuition:**
```
For each prompt p:
1. Generate N completions: {c₁, c₂, ..., cₙ}
2. Compute rewards: {r₁, r₂, ..., rₙ}
3. Learn to increase probability of high-reward completions
relative to low-reward ones in the same group
```
### 2. Reward Function Design Philosophy
**Golden Rules:**
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
2. **Scale rewards appropriately** - Higher weight = stronger signal
3. **Use incremental rewards** - Partial credit for partial compliance
4. **Test rewards independently** - Debug each reward function in isolation
**Reward Function Types:**
| Type | Use Case | Example Weight |
|------|----------|----------------|
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
| **Format** | Strict structure enforcement | 0.5-1.0 |
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
---
## Implementation Workflow
### Step 1: Dataset Preparation
**Critical Requirements:**
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns
**Example Structure:**
```python
from datasets import load_dataset, Dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
def prepare_dataset(raw_data):
"""
Transform raw data into GRPO-compatible format.
Returns: Dataset with columns:
- 'prompt': List[Dict] with role/content (system + user messages)
- 'answer': str (ground truth, optional but recommended)
"""
return raw_data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_answer(x['raw_answer'])
})
```
**Pro Tips:**
- Use one-shot or few-shot examples in system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out)
### Step 2: Reward Function Implementation
**Template Structure:**
```python
def reward_function_name(
prompts, # List[List[Dict]]: Original prompts
completions, # List[List[Dict]]: Model generations
answer=None, # Optional: Ground truth from dataset
**kwargs # Additional dataset columns
) -> list[float]:
"""
Evaluate completions and return rewards.
Returns: List of floats (one per completion)
"""
# Extract completion text
responses = [comp[0]['content'] for comp in completions]
# Compute rewards
rewards = []
for response in responses:
score = compute_score(response)
rewards.append(score)
return rewards
```
**Example 1: Correctness Reward (Math/Coding)**
```python
def correctness_reward(prompts, completions, answer, **kwargs):
"""Reward correct answers with high score."""
responses = [comp[0]['content'] for comp in completions]
extracted = [extract_final_answer(r) for r in responses]
return [2.0 if ans == gt else 0.0
for ans, gt in zip(extracted, answer)]
```
**Example 2: Format Reward (Structured Output)**
```python
import re
def format_reward(completions, **kwargs):
"""Reward XML-like structured format."""
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
responses = [comp[0]['content'] for comp in completions]
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
for r in responses]
```
**Example 3: Incremental Format Reward (Partial Credit)**
```python
def incremental_format_reward(completions, **kwargs):
"""Award partial credit for format compliance."""
responses = [comp[0]['content'] for comp in completions]
rewards = []
for r in responses:
score = 0.0
if '<reasoning>' in r:
score += 0.25
if '</reasoning>' in r:
score += 0.25
if '<answer>' in r:
score += 0.25
if '</answer>' in r:
score += 0.25
# Penalize extra text after closing tag
if r.count('</answer>') == 1:
extra_text = r.split('</answer>')[-1].strip()
score -= len(extra_text) * 0.001
rewards.append(score)
return rewards
```
**Critical Insight:**
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
### Step 3: Training Configuration
**Memory-Optimized Config (Small GPU)**
```python
from trl import GRPOConfig
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
# Learning rate
learning_rate=5e-6, # Lower = more stable
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
# Batch settings
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # Effective batch = 4
# GRPO-specific
num_generations=8, # Group size: 8-16 recommended
max_prompt_length=256,
max_completion_length=512,
# Training duration
num_train_epochs=1,
max_steps=-1, # -1 = train for num_train_epochs; or set fixed steps (e.g., 500)
# Optimization
bf16=True, # Faster on A100/H100
optim="adamw_8bit", # Memory-efficient optimizer
max_grad_norm=0.1,
# Logging
logging_steps=1,
save_steps=100,
report_to="wandb", # Or "none" for no logging
)
```
**High-Performance Config (Large GPU)**
```python
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
learning_rate=1e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
num_generations=16, # Larger groups = better signal
max_prompt_length=512,
max_completion_length=1024,
num_train_epochs=1,
bf16=True,
use_vllm=True, # Fast generation with vLLM
logging_steps=10,
)
```
**Critical Hyperparameters:**
| Parameter | Impact | Tuning Advice |
|-----------|--------|---------------|
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
### Step 4: Model Setup and Training
**Standard Setup (Transformers)**
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer
# Load model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 2-3x faster
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Optional: LoRA for parameter-efficient training
peft_config = LoraConfig(
r=16, # Rank (higher = more capacity)
lora_alpha=32, # Scaling factor (typically 2*r)
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
# Initialize trainer
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
incremental_format_reward,
format_reward,
correctness_reward,
],
args=training_args,
train_dataset=dataset,
peft_config=peft_config, # Remove for full fine-tuning
)
# Train
trainer.train()
# Save
trainer.save_model("final_model")
```
**Unsloth Setup (2-3x Faster)**
```python
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="google/gemma-3-1b-it",
max_seq_length=1024,
load_in_4bit=True,
fast_inference=True,
max_lora_rank=32,
)
model = FastLanguageModel.get_peft_model(
model,
r=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=32,
use_gradient_checkpointing="unsloth",
)
# Rest is identical to standard setup
trainer = GRPOTrainer(model=model, ...)
trainer.train()
```
---
## Critical Training Insights
### 1. Loss Behavior (EXPECTED PATTERN)
- **Loss starts near 0 and INCREASES during training**
- This is CORRECT - loss measures KL divergence from initial policy
- Model is learning (diverging from original behavior to optimize rewards)
- Monitor reward metrics instead of loss for progress
### 2. Reward Tracking
Key metrics to watch:
- `reward`: Average across all completions
- `reward_std`: Diversity within groups (should remain well above 0; aim for > 0.1)
- `kl`: KL divergence from reference (should grow moderately)
**Healthy Training Pattern:**
```
Step Reward Reward_Std KL
100 0.5 0.3 0.02
200 0.8 0.25 0.05
300 1.2 0.2 0.08 ← Good progression
400 1.5 0.15 0.12
```
**Warning Signs:**
- Reward std → 0: model is collapsing to a single response
- KL exploding (> 0.5): policy is diverging too much; reduce the learning rate
- Reward stuck: reward functions are too harsh, or the model lacks capacity
### 3. Common Pitfalls and Solutions
| Problem | Symptom | Solution |
|---------|---------|----------|
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
| **No learning** | Flat rewards | Check reward function logic, increase LR |
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
---
## Advanced Patterns
### 1. Multi-Stage Training
For complex tasks, train in stages:
```python
# Stage 1: Format compliance (epochs=1)
trainer_stage1 = GRPOTrainer(
model=model,
reward_funcs=[incremental_format_reward, format_reward],
...
)
trainer_stage1.train()
# Stage 2: Correctness (epochs=1)
trainer_stage2 = GRPOTrainer(
model=model,
reward_funcs=[format_reward, correctness_reward],
...
)
trainer_stage2.train()
```
### 2. Adaptive Reward Scaling
```python
class AdaptiveReward:
def __init__(self, base_reward_func, initial_weight=1.0):
self.func = base_reward_func
self.weight = initial_weight
def __call__(self, *args, **kwargs):
rewards = self.func(*args, **kwargs)
return [r * self.weight for r in rewards]
def adjust_weight(self, success_rate):
"""Increase weight if model struggling, decrease if succeeding."""
if success_rate < 0.3:
self.weight *= 1.2
elif success_rate > 0.8:
self.weight *= 0.9
```
### 3. Custom Dataset Integration
```python
def load_custom_knowledge_base(csv_path):
"""Example: School communication platform docs."""
import pandas as pd
df = pd.read_csv(csv_path)
dataset = Dataset.from_pandas(df).map(lambda x: {
'prompt': [
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': x['expert_answer']
})
return dataset
```
---
## Deployment and Inference
### Save and Merge LoRA
```python
# Merge LoRA adapters into base model
if hasattr(trainer.model, 'merge_and_unload'):
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("production_model")
tokenizer.save_pretrained("production_model")
```
### Inference Example
```python
from transformers import pipeline
generator = pipeline(
"text-generation",
model="production_model",
tokenizer=tokenizer
)
result = generator(
[
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': "What is 15 + 27?"}
],
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9
)
print(result[0]['generated_text'])
```
---
## Best Practices Checklist
**Before Training:**
- [ ] Validate dataset format (prompts as List[Dict])
- [ ] Test reward functions on sample data
- [ ] Calculate expected max_prompt_length from data
- [ ] Choose appropriate num_generations based on GPU memory
- [ ] Set up logging (wandb recommended)
**During Training:**
- [ ] Monitor reward progression (should increase)
- [ ] Check reward_std (should stay > 0.1)
- [ ] Watch for OOM errors (reduce batch size if needed)
- [ ] Sample generations every 50-100 steps
- [ ] Validate format compliance on holdout set
**After Training:**
- [ ] Merge LoRA weights if using PEFT
- [ ] Test on diverse prompts
- [ ] Compare to baseline model
- [ ] Document reward weights and hyperparameters
- [ ] Save reproducibility config
---
## Troubleshooting Guide
### Debugging Workflow
1. **Isolate reward functions** - Test each independently
2. **Check data distribution** - Ensure diversity in prompts
3. **Reduce complexity** - Start with single reward, add gradually
4. **Monitor generations** - Print samples every N steps
5. **Validate extraction logic** - Ensure answer parsing works
### Quick Fixes
```python
# Debug reward function
def debug_reward(completions, **kwargs):
responses = [comp[0]['content'] for comp in completions]
for i, r in enumerate(responses[:2]): # Print first 2
print(f"Response {i}: {r[:200]}...")
return [1.0] * len(responses) # Dummy rewards
# Test without training
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
trainer.generate_completions(dataset[:1]) # Generate without updating
```
---
## References and Resources
**Official Documentation:**
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
- Unsloth Docs: https://docs.unsloth.ai/
**Example Repositories:**
- Open R1 Implementation: https://github.com/huggingface/open-r1
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
**Recommended Reading:**
- Progressive Disclosure Pattern for agent instructions
- Reward shaping in RL (Ng et al.)
- LoRA paper (Hu et al., 2021)
---
## Usage Instructions for Agents
When this skill is loaded:
1. **Read this entire file** before implementing GRPO training
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
3. **Use the templates** in `templates/` directory as starting points
4. **Reference examples** in `examples/` for task-specific implementations
5. **Follow the workflow** sequentially (don't skip steps)
6. **Debug incrementally** - add one reward function at a time
**Critical Reminders:**
- Always use multiple reward functions (3-5 is optimal)
- Monitor reward metrics, not loss
- Test reward functions before training
- Start small (num_generations=4), scale up gradually
- Save checkpoints frequently (every 100 steps)
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
@@ -0,0 +1,393 @@
"""
GRPO Reward Functions Library
===============================
A collection of battle-tested reward functions for common GRPO training scenarios.
Copy and adapt these for your specific use case.
Categories:
- Correctness rewards (verifiable tasks)
- Format rewards (structured output)
- Length rewards (verbosity control)
- Style rewards (quality and tone)
- Combined rewards (multi-objective)
"""
import re
from typing import List, Any
# ==================== CORRECTNESS REWARDS ====================
def exact_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    All-or-nothing reward for reproducing the reference answer exactly.
    Use for: Math problems, factual Q&A, code output
    Weight: 2.0 (highest priority)

    Args:
        prompts: Accepted for signature compatibility; not read here.
        completions: Sequence whose items expose the text at [0]['content'].
        answer: Reference answers, paired positionally with completions.

    Returns:
        One score per (completion, answer) pair: 2.0 on a match after
        stripping surrounding whitespace from both sides, otherwise 0.0.
    """
    # Pull the <answer> payload out of every completion first.
    predictions = []
    for item in completions:
        predictions.append(extract_answer(item[0]['content']))

    # Then compare each prediction with its reference.
    scores = []
    for predicted, truth in zip(predictions, answer):
        is_match = predicted.strip() == truth.strip()
        scores.append(2.0 if is_match else 0.0)
    return scores
def fuzzy_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Graded reward based on textual similarity to the reference answer.
    Use for: Open-ended answers, summaries
    Weight: 1.0

    Args:
        prompts: Accepted for signature compatibility; not read here.
        completions: Sequence whose items expose the text at [0]['content'].
        answer: Reference answers, paired positionally with completions.

    Returns:
        One difflib SequenceMatcher ratio per pair, computed on the
        lower-cased extracted answer and lower-cased reference.
    """
    from difflib import SequenceMatcher

    predictions = [extract_answer(item[0]['content']) for item in completions]
    return [
        SequenceMatcher(None, predicted.lower(), truth.lower()).ratio()
        for predicted, truth in zip(predictions, answer)
    ]
def numeric_correctness_reward(prompts, completions, answer, tolerance=0.01, **kwargs) -> List[float]:
    """
    Reward numeric answers within a relative tolerance of the ground truth.
    Use for: Math, physics, engineering problems
    Weight: 2.0

    Args:
        prompts: Accepted for signature compatibility; not read here.
        completions: Sequence whose items expose the text at [0]['content'].
        answer: Ground-truth values paired positionally with completions.
            Strings (thousands separators allowed) and plain numbers are
            both accepted.
        tolerance: Maximum allowed relative error, |ans - gt| / |gt|.

    Returns:
        2.0 for each answer within tolerance, 0.0 otherwise (including
        answers or ground truths that cannot be parsed as a number).
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_answer(r) for r in responses]
    rewards = []
    for ans, gt in zip(extracted, answer):
        try:
            ans_num = float(ans.replace(',', ''))
            # str() lets a numeric ground truth through instead of failing
            # on the missing .replace attribute and silently scoring 0.0.
            gt_num = float(str(gt).replace(',', ''))
        except (TypeError, ValueError):
            # Only parse failures are treated as "wrong answer"; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            rewards.append(0.0)
            continue
        # The 1e-8 floor keeps the division defined when the truth is 0.
        relative_error = abs(ans_num - gt_num) / max(abs(gt_num), 1e-8)
        rewards.append(2.0 if relative_error <= tolerance else 0.0)
    return rewards
def code_execution_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
    """
    Execute generated code and verify it against test cases.
    Use for: Code generation tasks
    Weight: 2.0

    Args:
        prompts: Accepted for signature compatibility; not read here.
        completions: Sequence whose items expose the text at [0]['content'].
        test_cases: Passed unchanged to run_test_cases for every completion.

    Returns:
        2.0 for each completion whose extracted code passes, else 0.0.
    """
    # NOTE(review): the same test_cases object is used for every completion,
    # whereas the other rewards pair ground truth per completion with zip().
    # Confirm this matches how the dataset column is shaped.
    responses = [comp[0]['content'] for comp in completions]
    extracted_code = [extract_code_block(r) for r in responses]
    rewards = []
    for code in extracted_code:
        try:
            # Execute code (sandboxed!)
            passed = run_test_cases(code, test_cases)
        except Exception:
            # A crashing candidate scores zero; interpreter-level signals
            # such as KeyboardInterrupt are deliberately not swallowed.
            passed = False
        rewards.append(2.0 if passed else 0.0)
    return rewards
# ==================== FORMAT REWARDS ====================
def strict_xml_format_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions that follow the XML layout to the character.
    Use for: When format must be EXACTLY specified
    Weight: 0.5

    The whole completion must be a <reasoning> block followed by an
    <answer> block, each tag on its own line, ending with a newline.

    Returns:
        0.5 per completion that matches the layout, otherwise 0.0.
    """
    strict_layout = re.compile(
        r'^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$',
        re.DOTALL,
    )
    scores = []
    for item in completions:
        text = item[0]['content']
        scores.append(0.5 if strict_layout.match(text) else 0.0)
    return scores
def soft_xml_format_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions that contain the XML structure anywhere in the text.
    Use for: When structure matters more than exact spacing
    Weight: 0.5

    A <reasoning>...</reasoning> block must be followed, with only
    whitespace in between, by an <answer>...</answer> block.

    Returns:
        0.5 per completion containing the structure, otherwise 0.0.
    """
    relaxed_layout = re.compile(
        r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>',
        re.DOTALL,
    )
    scores = []
    for item in completions:
        found = relaxed_layout.search(item[0]['content'])
        scores.append(0.5 if found else 0.0)
    return scores
def json_format_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions whose entire text parses as JSON.
    Use for: Structured data extraction, API responses
    Weight: 0.5

    Args:
        completions: Sequence whose items expose the text at [0]['content'].

    Returns:
        0.5 per completion accepted by json.loads, otherwise 0.0.
    """
    import json

    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        try:
            json.loads(r)
        except (ValueError, TypeError, RecursionError):
            # ValueError covers JSONDecodeError, TypeError covers non-text
            # content, RecursionError covers pathologically deep nesting.
            # Anything else (e.g. KeyboardInterrupt) is left to propagate.
            rewards.append(0.0)
        else:
            rewards.append(0.5)
    return rewards
def incremental_format_reward(completions, tags=('reasoning', 'answer'), **kwargs) -> List[float]:
    """
    Partial credit for each required tag.
    Use for: Training models to gradually learn format
    Weight: 0.125 per opening tag and 0.125 per closing tag found,
            i.e. up to 0.25 * len(tags) (0.5 for the default two tags)

    Args:
        completions: Sequence whose items expose the text at [0]['content'].
        tags: Tag names to look for, in output order. The last one is
            treated as the final tag for the trailing-text penalty.

    Returns:
        One score per completion. Text after the final closing tag costs
        0.001 per character, so a score can drop below zero.
    """
    responses = [comp[0]['content'] for comp in completions]
    # With no tags there is no final closing tag to penalise against.
    final_close = f'</{tags[-1]}>' if tags else None
    rewards = []
    for r in responses:
        score = 0.0
        for tag in tags:
            if f'<{tag}>' in r:
                score += 0.125
            if f'</{tag}>' in r:
                score += 0.125
        # Penalize extra content after final closing tag
        if final_close is not None and final_close in r:
            extra = r.split(final_close)[-1].strip()
            score -= len(extra) * 0.001
        rewards.append(score)
    return rewards
# ==================== LENGTH REWARDS ====================
def ideal_length_reward(completions, ideal_tokens=100, **kwargs) -> List[float]:
    """
    Reward responses near ideal length.
    Use for: Controlling verbosity
    Weight: 0.3

    Length is measured in whitespace-separated words. The reward is 0.3
    at exactly ideal_tokens words and falls off linearly to 0 at zero
    words and at twice the ideal length.

    Args:
        completions: Sequence whose items expose the text at [0]['content'].
        ideal_tokens: Target word count; must be positive.

    Returns:
        One reward in [0, 0.3] per completion.

    Raises:
        ValueError: If ideal_tokens is not positive. Zero would divide by
            zero and a negative value would produce rewards above 0.3.
    """
    if ideal_tokens <= 0:
        raise ValueError(f"ideal_tokens must be positive, got {ideal_tokens!r}")
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        length = len(r.split())
        distance = abs(length - ideal_tokens)
        # Triangular reward peaking at the ideal length
        reward = 0.3 * max(0, 1 - distance / ideal_tokens)
        rewards.append(reward)
    return rewards
def min_length_reward(completions, min_tokens=50, **kwargs) -> List[float]:
    """
    Bonus for responses that reach a minimum length, penalty otherwise.
    Use for: Ensuring detailed explanations
    Weight: 0.2

    Length is the number of whitespace-separated words.

    Returns:
        0.2 per completion with at least min_tokens words, otherwise -0.2.
    """
    def score(text):
        word_count = len(text.split())
        if word_count < min_tokens:
            return -0.2
        return 0.2

    texts = [item[0]['content'] for item in completions]
    return [score(text) for text in texts]
def max_length_penalty(completions, max_tokens=500, **kwargs) -> List[float]:
    """
    Flat penalty for responses that run past a length ceiling.
    Use for: Preventing rambling
    Weight: -0.3 when violated

    Length is the number of whitespace-separated words.

    Returns:
        -0.3 per completion with more than max_tokens words, otherwise 0.0.
    """
    texts = [item[0]['content'] for item in completions]
    penalties = []
    for text in texts:
        too_long = len(text.split()) > max_tokens
        penalties.append(-0.3 if too_long else 0.0)
    return penalties
# ==================== STYLE REWARDS ====================
def reasoning_quality_reward(completions, **kwargs) -> List[float]:
    """
    Reward detailed reasoning with logical connectors.
    Use for: Improving chain-of-thought quality
    Weight: 0.3

    Only the text inside the <reasoning> tag is inspected. Each distinct
    connector that appears as a whole word earns 0.05, capped at 0.3.

    Args:
        completions: Sequence whose items expose the text at [0]['content'].

    Returns:
        One reward in [0, 0.3] per completion.
    """
    logical_words = ['therefore', 'thus', 'because', 'since', 'consequently',
                     'first', 'second', 'next', 'finally', 'however']
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        reasoning = extract_xml_tag(r, 'reasoning').lower()
        # Match whole words only, so "enthusiasm" does not count as "thus"
        # and "sincere" does not count as "since".
        words_present = set(re.findall(r'[a-z]+', reasoning))
        count = sum(1 for word in logical_words if word in words_present)
        # Cap the score so more connectors cannot exceed the stated weight
        score = min(0.3, count * 0.05)
        rewards.append(score)
    return rewards
def citation_reward(completions, **kwargs) -> List[float]:
    """
    Reward responses with citations or references.
    Use for: Research tasks, fact-checking
    Weight: 0.2

    A response counts as cited when it contains a bracketed number, an
    "(Author, Year)" reference, or one of the attribution phrases. The
    phrases are matched case-insensitively so that a sentence starting
    with "According to" is recognised.

    Args:
        completions: Sequence whose items expose the text at [0]['content'].

    Returns:
        0.2 per completion containing a citation marker, otherwise 0.0.
    """
    citation_patterns = [
        re.compile(r'\[\d+\]'),  # [1], [2]
        re.compile(r'\([A-Z][a-z]+,?\s+\d{4}\)'),  # (Smith, 2020)
        re.compile(r'according to', re.IGNORECASE),
        re.compile(r'as stated in', re.IGNORECASE),
    ]
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        has_citation = any(pattern.search(r) for pattern in citation_patterns)
        rewards.append(0.2 if has_citation else 0.0)
    return rewards
def no_repetition_penalty(completions, **kwargs) -> List[float]:
    """
    Penalize repetitive text (same phrase repeated).
    Use for: Improving output diversity
    Weight: -0.3 when repetitive

    Repetition is measured as the share of distinct word trigrams in the
    lower-cased response; below 70% distinct counts as repetitive.

    Args:
        completions: Sequence whose items expose the text at [0]['content'].

    Returns:
        -0.3 per repetitive completion, otherwise 0.0. Responses with
        fewer than three words have no trigrams and are never penalized.
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        words = r.lower().split()
        # Check for repeated trigrams
        trigrams = [' '.join(words[i:i+3]) for i in range(len(words)-2)]
        if not trigrams:
            # Too short to repeat anything; an empty trigram list must not
            # be read as "0% unique".
            rewards.append(0.0)
            continue
        unique_ratio = len(set(trigrams)) / len(trigrams)
        reward = -0.3 if unique_ratio < 0.7 else 0.0
        rewards.append(reward)
    return rewards
# ==================== COMBINED REWARDS ====================
def math_problem_reward(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Single combined score for math problems: relaxed XML format plus
    exact-match correctness.
    Weight: 2.5 total

    Args:
        prompts: Forwarded to exact_match_reward.
        completions: Forwarded to both component rewards.
        answer: Reference answers, forwarded to exact_match_reward.

    Returns:
        The element-wise sum of the two component reward lists.
    """
    structure_scores = soft_xml_format_reward(completions)
    accuracy_scores = exact_match_reward(prompts, completions, answer)

    combined = []
    for structure, accuracy in zip(structure_scores, accuracy_scores):
        combined.append(structure + accuracy)
    return combined
def code_block_format_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions that contain an extractable fenced code block.
    Use for: Code generation tasks
    Weight: 0.5

    Returns:
        0.5 per completion from which extract_code_block returns non-empty
        code, otherwise 0.0.
    """
    responses = [comp[0]['content'] for comp in completions]
    return [0.5 if extract_code_block(r) else 0.0 for r in responses]


def no_syntax_error_reward(completions, **kwargs) -> List[float]:
    """
    Reward completions whose extracted code is syntactically valid Python.
    Use for: Code generation tasks
    Weight: 0.2

    Returns:
        0.2 per completion whose code block compiles, otherwise 0.0
        (including completions with no code block).
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        code = extract_code_block(r)
        if not code:
            rewards.append(0.0)
            continue
        try:
            # compile() only parses and builds bytecode; nothing is executed.
            compile(code, '<completion>', 'exec')
        except (SyntaxError, ValueError, OverflowError, RecursionError):
            rewards.append(0.0)
        else:
            rewards.append(0.2)
    return rewards


def code_generation_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
    """
    Combined reward for code: format + execution + style.
    Weight: 2.7 total (0.5 code block + 2.0 passing tests + 0.2 valid syntax)

    Args:
        prompts: Forwarded to code_execution_reward.
        completions: Forwarded to all three component rewards.
        test_cases: Forwarded to code_execution_reward.

    Returns:
        The element-wise sum of the three component reward lists.
    """
    code_format_rewards = code_block_format_reward(completions)
    execution_rewards = code_execution_reward(prompts, completions, test_cases)
    no_error_rewards = no_syntax_error_reward(completions)
    return [f + e + s for f, e, s in zip(code_format_rewards, execution_rewards, no_error_rewards)]
# ==================== HELPER FUNCTIONS ====================
def extract_answer(text: str) -> str:
    """Return whatever extract_xml_tag finds for the 'answer' tag in text."""
    answer_body = extract_xml_tag(text, 'answer')
    return answer_body
def extract_xml_tag(text: str, tag: str) -> str:
    """
    Generic XML tag extraction.

    Args:
        text: Text to search.
        tag: Tag name, matched literally (regex metacharacters in the
            name have no special meaning).

    Returns:
        The whitespace-stripped content of the first <tag>...</tag> pair,
        which may span several lines, or "" when no pair is present.
    """
    # Escape the name so a tag such as "a.b" cannot match "<aXb>".
    name = re.escape(tag)
    pattern = f'<{name}>(.*?)</{name}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
def extract_code_block(text: str) -> str:
    """
    Return the body of the first fenced markdown code block in text.

    The fence may be bare or tagged `python`. Returns "" when no such
    block is found.
    """
    fenced = re.search(r'```(?:python)?\n(.*?)\n```', text, re.DOTALL)
    if fenced is None:
        return ""
    return fenced.group(1)
def run_test_cases(code: str, test_cases: List[tuple]) -> bool:
    """
    Execute code with test cases (MUST be sandboxed in production!).

    The code is run with exec() in a fresh namespace and is expected to
    define a callable named `solution` taking one argument.

    Args:
        code: Python code string
        test_cases: List of (input, expected_output) tuples

    Returns:
        True if all tests pass. False if any result differs, or if the
        code fails to compile, raises, does not define `solution`, or
        tries to exit the interpreter.
    """
    # WARNING: This is a simplified example
    # In production, use proper sandboxing (e.g., docker, pypy sandbox)
    # SECURITY: exec() runs model-generated code with the full privileges
    # of this process; nothing here restricts what that code can do.
    try:
        exec_globals = {}
        exec(code, exec_globals)
        for input_val, expected in test_cases:
            result = exec_globals['solution'](input_val)
            if result != expected:
                return False
        return True
    except (Exception, SystemExit):
        # Generated code calling sys.exit() must fail the test rather than
        # stop training, but KeyboardInterrupt still reaches the operator.
        return False
# ==================== REWARD FUNCTION PRESETS ====================
# Each preset below is a plain list of reward callables from this module.
# They appear intended to be passed as a trainer's `reward_funcs` list
# (that usage is not shown in this module).

# Preset for math/reasoning tasks:
# partial-format credit, relaxed XML format, exact-match correctness,
# and a bonus for logical connectors in the reasoning.
MATH_REASONING_REWARDS = [
    incremental_format_reward,
    soft_xml_format_reward,
    exact_match_reward,
    reasoning_quality_reward,
]
# Preset for code generation:
# code-block presence, test-case execution, and syntax validity.
CODE_GENERATION_REWARDS = [
    code_block_format_reward,
    code_execution_reward,
    no_syntax_error_reward,
]
# Preset for summarization:
# target length, similarity to the reference, and a repetition penalty.
SUMMARIZATION_REWARDS = [
    ideal_length_reward,
    fuzzy_match_reward,
    no_repetition_penalty,
]
# Preset for Q&A:
# exact-match correctness, a minimum length, and a citation bonus.
QA_REWARDS = [
    exact_match_reward,
    min_length_reward,
    citation_reward,
]
@@ -0,0 +1,228 @@
"""
Basic GRPO Training Template
=============================
A minimal, production-ready template for GRPO training with TRL.
Adapt this for your specific task by modifying:
1. Dataset loading (get_dataset function)
2. Reward functions (*_reward_func)
3. System prompt (SYSTEM_PROMPT)
4. Hyperparameters (GRPOConfig)
"""
import torch
import re
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig
# ==================== CONFIGURATION ====================
# Checkpoint identifier passed to from_pretrained for both model and tokenizer.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# Used as GRPOConfig.output_dir; the final model is saved under "<OUTPUT_DIR>/final".
OUTPUT_DIR = "outputs/grpo-model"
# Forwarded to GRPOConfig as max_prompt_length / max_completion_length
# (presumably token counts - confirm against the installed TRL version).
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 512
# Sent as the system message of every prompt. The <reasoning>/<answer> tags
# named here are the ones the reward functions below look for.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
# ==================== DATASET ====================
def get_dataset(split="train"):
    """
    Load GSM8K and reshape every row into a chat prompt plus reference answer.

    Args:
        split: Name of the dataset split to load.

    Returns: the mapped dataset, whose rows gain (or have overwritten):
        - 'prompt': List[Dict] with role/content (system prompt, then the question)
        - 'answer': str with the text after '####', or None when the
          marker is missing
    """
    def to_chat_example(row):
        # The reference answer is whatever follows the '####' marker.
        solution_text = row['answer']
        if '####' in solution_text:
            final_answer = solution_text.split('####')[1].strip()
        else:
            final_answer = None

        conversation = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']}
        ]
        return {'prompt': conversation, 'answer': final_answer}

    # Example: GSM8K math dataset
    raw_split = load_dataset('openai/gsm8k', 'main')[split]
    return raw_split.map(to_chat_example)
# ==================== HELPER FUNCTIONS ====================
def extract_xml_tag(text: str, tag: str) -> str:
    """
    Extract content between XML tags.

    Args:
        text: Text to search.
        tag: Tag name, matched literally (regex metacharacters in the
            name have no special meaning).

    Returns:
        The whitespace-stripped content of the first <tag>...</tag> pair,
        which may span several lines, or "" when no pair is present.
    """
    # Escape the name so a tag such as "a.b" cannot match "<aXb>".
    name = re.escape(tag)
    pattern = f'<{name}>(.*?)</{name}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
def extract_answer(text: str) -> str:
    """Return whatever extract_xml_tag finds for the 'answer' tag in text."""
    answer_body = extract_xml_tag(text, 'answer')
    return answer_body
# ==================== REWARD FUNCTIONS ====================
def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Score each completion on whether its extracted answer equals the reference.
    Weight: 2.0 (highest priority)

    Args:
        prompts: Accepted for signature compatibility; not read here.
        completions: Sequence whose items expose the text at [0]['content'].
        answer: Reference answers, paired positionally with completions.

    Returns:
        2.0 per exact match, otherwise 0.0.
    """
    predictions = []
    for item in completions:
        predictions.append(extract_answer(item[0]['content']))

    scores = []
    for predicted, truth in zip(predictions, answer):
        scores.append(2.0 if predicted == truth else 0.0)
    return scores
def format_reward_func(completions, **kwargs):
    """
    Score completions on containing the expected XML structure.
    Weight: 0.5

    A <reasoning>...</reasoning> block must be followed, with only
    whitespace in between, by an <answer>...</answer> block.

    Returns:
        0.5 per completion containing the structure, otherwise 0.0.
    """
    expected_layout = re.compile(
        r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>',
        re.DOTALL,
    )
    texts = [item[0]['content'] for item in completions]
    scores = []
    for text in texts:
        scores.append(0.5 if expected_layout.search(text) else 0.0)
    return scores
def incremental_format_reward_func(completions, **kwargs):
    """
    Graduated score for partially following the XML format.
    Weight: up to 0.5

    Each of the four markers found in the text earns 0.125. Text after
    the last '</answer>' costs 0.001 per character, so a score can drop
    below zero.

    Returns:
        One score per completion.
    """
    markers = ('<reasoning>', '</reasoning>', '<answer>', '</answer>')
    texts = [item[0]['content'] for item in completions]

    scores = []
    for text in texts:
        # The 0.0 start value keeps the result a float when no marker is found.
        total = sum((0.125 for marker in markers if marker in text), 0.0)
        # Penalize extra content after closing tag
        if '</answer>' in text:
            trailing = text.split('</answer>')[-1].strip()
            total -= len(trailing) * 0.001
        scores.append(total)
    return scores
# ==================== MODEL SETUP ====================
def setup_model_and_tokenizer():
    """
    Load the MODEL_NAME checkpoint and its tokenizer.

    The model is requested in bfloat16 with the "flash_attention_2"
    attention implementation and device_map="auto".

    Returns:
        (model, tokenizer) tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Fall back to EOS only when the checkpoint defines no pad token, so a
    # model-specific pad token is not silently replaced.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
def get_peft_config():
    """Build the LoraConfig handed to the trainer as `peft_config`."""
    adapted_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
    lora_settings = dict(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )
    return LoraConfig(target_modules=adapted_modules, **lora_settings)
# ==================== TRAINING ====================
def main():
    """Run the whole pipeline: load data and model, configure, train, save."""
    # Data
    print("Loading dataset...")
    train_data = get_dataset()
    print(f"Dataset size: {len(train_data)}")

    # Model
    print("Loading model...")
    policy_model, policy_tokenizer = setup_model_and_tokenizer()

    # Hyperparameters, grouped by concern and merged into one GRPOConfig.
    optimizer_settings = dict(
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        bf16=True,
        optim="adamw_8bit",
        max_grad_norm=0.1,
    )
    # NOTE(review): some TRL versions require the effective batch size to be
    # divisible by num_generations - confirm these values against the
    # installed TRL and the number of devices in use.
    batch_settings = dict(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
    )
    generation_settings = dict(
        num_generations=8,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
    )
    logging_settings = dict(
        logging_steps=1,
        save_steps=100,
        report_to="wandb",  # Change to "none" to disable logging
    )
    grpo_config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        run_name="grpo-training",
        **optimizer_settings,
        **batch_settings,
        **generation_settings,
        **logging_settings,
    )

    # Trainer: rewards are listed in the order they are handed over.
    reward_pipeline = [
        incremental_format_reward_func,
        format_reward_func,
        correctness_reward_func,
    ]
    trainer = GRPOTrainer(
        model=policy_model,
        processing_class=policy_tokenizer,
        reward_funcs=reward_pipeline,
        args=grpo_config,
        train_dataset=train_data,
        peft_config=get_peft_config(),
    )

    # Train, then persist the result next to the checkpoints.
    print("Starting training...")
    trainer.train()
    print(f"Saving model to {OUTPUT_DIR}/final")
    trainer.save_model(f"{OUTPUT_DIR}/final")
    print("Training complete!")


if __name__ == "__main__":
    main()
@@ -0,0 +1,315 @@
---
name: miles-rl-training
description: Provides guidance for enterprise-grade RL training using miles, a production-ready fork of slime. Use when training large MoE models with FP8/INT4, needing train-inference alignment, or requiring speculative RL for maximum throughput.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Reinforcement Learning, MoE, FP8, INT4, Enterprise, SGLang, Megatron-LM]
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
---
# miles: Enterprise-Grade RL for Large-Scale Model Training
miles is a high-performance, enterprise-ready RL framework optimized for large-scale model post-training. Built as a production fork of slime, it addresses critical challenges in MoE training stability, low-precision training, and train-inference alignment.
## When to Use miles
**Choose miles when you need:**
- Training 1TB+ MoE models (DeepSeek V3, Qwen3-MoE)
- FP8 or INT4 quantization-aware training
- Bit-wise identical train-inference alignment
- Speculative RL for maximum throughput
- Production stability with enterprise support
**Consider alternatives when:**
- You want the research-grade original → use **slime**
- You need flexible backend swapping → use **verl**
- You want PyTorch-native abstractions → use **torchforge**
## Key Features
### Low-Precision Training
- **Unified FP8**: End-to-end FP8 for both inference and training
- **INT4 QAT**: 1TB models on single-machine VRAM (H200)
- **Rollout Routing Replay (R3)**: Bit-wise expert alignment for MoE
### Performance Optimizations
- **Speculative RL**: 25%+ rollout speedup with online SFT draft models
- **Zero-Copy Weight Sync**: CUDA IPC zero-copy mapping
- **Partial Rollout**: Recycle half-finished trajectories
### Train-Inference Alignment
- **TIS/MIS**: Truncated/Masked Importance Sampling for off-policy correction
- **Kernel-level optimization**: FlashAttention-3, DeepGEMM integration
## Installation
```bash
# Recommended: Docker
docker pull radixark/miles:latest
docker run --rm --gpus all --ipc=host --shm-size=16g \
-it radixark/miles:latest /bin/bash
# From source
git clone https://github.com/radixark/miles.git
cd miles
pip install -r requirements.txt
pip install -e .
```
## Quick Start
miles inherits slime's configuration system. Basic training:
```bash
python train.py \
--advantage-estimator grpo \
--model-name qwen3-30b-a3b \
--hf-checkpoint /path/to/qwen3-30b-a3b-hf \
--rollout-batch-size 512 \
--n-samples-per-prompt 8
```
---
## Workflow 1: Large MoE Training
Use this workflow for training large MoE models like DeepSeek V3 or Qwen3-MoE.
### Prerequisites Checklist
- [ ] H100/H200 GPUs with FP8 support
- [ ] MoE model (DeepSeek V3, Qwen3-MoE)
- [ ] Docker environment with miles
### Step 1: Environment Setup
```bash
# FP8 block scaling (recommended for stability)
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
```
### Step 2: Configure Training
```bash
python train.py \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
--hf-checkpoint /path/to/deepseek-v3 \
--advantage-estimator grpo \
--tensor-model-parallel-size 8 \
--expert-model-parallel-size 4 \
--prompt-data /path/to/data.jsonl \
--num-rollout 3000
```
### Verification Checklist
- [ ] Model loads without errors
- [ ] Routing decisions are consistent
- [ ] No NaN/Inf in loss values
---
## Workflow 2: Speculative RL Training
Use this workflow for maximum rollout throughput with EAGLE speculative decoding.
### How Speculative RL Works
1. Small draft model generates candidate tokens
2. Target model verifies in parallel
3. Draft model updated via online SFT to track policy
### Step 1: Enable Speculative Decoding
miles supports EAGLE speculative decoding via SGLang:
```bash
python train.py \
--actor-num-gpus-per-node 8 \
--hf-checkpoint /path/to/target-model \
--sglang-speculative-algorithm EAGLE \
--sglang-speculative-num-steps 3 \
--sglang-speculative-eagle-topk 1 \
--sglang-speculative-num-draft-tokens 4 \
--sglang-speculative-draft-model-path /path/to/draft-model \
--advantage-estimator grpo \
--prompt-data /path/to/data.jsonl
```
### Step 2: Enable Online MTP Training (Optional)
For online SFT of draft model during training:
```bash
--mtp-num-layers 1 \
--enable-mtp-training \
--mtp-loss-scaling-factor 0.2
```
**Note**: Online MTP training requires a torch dist checkpoint with MTP weights. Add `--mtp-num-layers 1` during checkpoint conversion from HuggingFace.
### Expected Speedup
- **Standard rollout**: Baseline
- **Speculative RL**: 25-40% faster rollout
- **With partial rollout**: Additional 10-15% throughput
---
## Configuration Reference
miles inherits all slime arguments. See [slime API Reference](../slime/references/api-reference.md) for the complete list.
### Cluster Resources (from slime)
```bash
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--rollout-num-gpus 8
--rollout-num-gpus-per-engine 2
--colocate
```
### Megatron Parallelism (from slime)
```bash
--tensor-model-parallel-size 8
--pipeline-model-parallel-size 2
--expert-model-parallel-size 4 # MoE expert parallelism
```
### Speculative Decoding (miles-specific)
```bash
--sglang-speculative-algorithm EAGLE
--sglang-speculative-num-steps 3
--sglang-speculative-eagle-topk 1
--sglang-speculative-num-draft-tokens 4
--sglang-enable-draft-weights-cpu-backup
--sglang-speculative-draft-model-path /your/draft/model/path
```
### Online MTP Training (miles-specific)
```bash
--mtp-num-layers 1
--enable-mtp-training
--mtp-loss-scaling-factor 0.2
```
---
## Key Features (Conceptual)
The following features are documented in miles but specific CLI flags may vary. Consult the miles repository for latest configuration.
### Unified FP8 Pipeline
End-to-end FP8 sampling and training that eliminates quantization-induced discrepancy causing RL collapse in MoE models.
### Rollout Routing Replay (R3)
Records expert routing decisions during SGLang inference and replays them during Megatron training for bit-wise expert alignment.
**How R3 Works**:
1. During SGLang inference, expert routing decisions are recorded
2. Routing decisions stored in `sample.rollout_routed_experts`
3. During Megatron training, routing is replayed instead of recomputed
4. Ensures identical expert selection between train and inference
### INT4 Quantization-Aware Training
Enables single-machine deployment of 1TB+ models (e.g., on H200).
**Memory Savings with INT4**:
| Model Size | BF16 VRAM | INT4 VRAM | Reduction |
|------------|-----------|-----------|-----------|
| 70B | 140GB | 45GB | 3.1x |
| 235B | 470GB | 150GB | 3.1x |
| 671B | 1.3TB | 420GB | 3.1x |
### Train-Inference Alignment
miles achieves "exactly 0 KL divergence" between training and inference through:
- Flash Attention 3
- DeepGEMM
- Batch-invariant kernels from Thinking Machines Lab
- `torch.compile` integration
---
## Sample Data Structure
miles uses the same `Sample` dataclass as slime with the `rollout_routed_experts` field for MoE routing replay:
```python
@dataclass
class Sample:
prompt: str | list[dict]
tokens: list[int]
response: str
reward: float | dict
loss_mask: list[int]
status: Status
metadata: dict
rollout_log_probs: list[float]
rollout_routed_experts: list[list[int]] # MoE routing for R3
```
See [slime API Reference](../slime/references/api-reference.md) for the complete Sample definition.
---
## Common Issues and Solutions
### Issue: FP8 Training Collapse
**Symptoms**: Loss explodes, NaN values
**Solutions**:
- Use block scaling: `export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1`
- Reduce learning rate: `--lr 5e-7`
- Ensure MoE routing is consistent between train/inference
### Issue: Speculative Draft Drift
**Symptoms**: Low acceptance rate over time
**Solutions**:
- Enable online MTP training to keep draft model aligned
- Reduce speculative steps: `--sglang-speculative-num-steps 2`
- Use CPU backup: `--sglang-enable-draft-weights-cpu-backup`
### Issue: Train-Inference Mismatch
**Symptoms**: Policy divergence, reward collapse
**Solutions**:
- Use TIS for off-policy correction: `--use-tis --tis-threshold 0.9`
- Verify log probs match between SGLang and Megatron
- Enable R3 for MoE models
---
## Supported Models
| Family | Models | MoE Support |
|--------|--------|-------------|
| DeepSeek | R1, V3, V3.2 | Full |
| Qwen | 2, 2.5, 3 (including MoE) | Full |
| Llama | 3, 3.1, 3.3, 4 | Dense only |
| Gemma | 2, 3, 3N | Dense only |
| GLM | 4.5, 4.6, 4.7 | Dense only |
| MiniMax | M2, M2.1 | Full |
---
## Resources
- **GitHub**: https://github.com/radixark/miles
- **Introduction Blog**: https://lmsys.org/blog/2025-11-19-miles/
- **Slime (upstream)**: https://github.com/THUDM/slime
- **SGLang**: https://github.com/sgl-project/sglang
@@ -0,0 +1,141 @@
# miles API Reference
## Overview
miles is an enterprise-grade RL framework built on slime, adding advanced features for large-scale MoE training:
- Unified FP8 training and inference
- INT4 Quantization-Aware Training
- Rollout Routing Replay (R3)
- Speculative RL training
**Note**: miles inherits slime's configuration system. See [slime API Reference](../../slime/references/api-reference.md) for base arguments.
## Core Data Structures
miles uses the same `Sample` dataclass as slime with the `rollout_routed_experts` field for MoE routing replay.
## Quick Start
```bash
python train.py \
--advantage-estimator grpo \
--model-name qwen3-30b-a3b \
--hf-checkpoint /path/to/qwen3-30b-a3b-hf \
--rollout-batch-size 512 \
--n-samples-per-prompt 8
```
## Configuration Options
miles inherits slime's three argument categories (Megatron, SGLang with `--sglang-` prefix, and slime-specific). The most commonly used inherited options, followed by the miles-specific additions, are listed below:
### Cluster Resources (inherited from slime)
```bash
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--rollout-num-gpus 8
--rollout-num-gpus-per-engine 2
--colocate
```
### Megatron Parallelism (inherited from slime)
```bash
--tensor-model-parallel-size 8
--pipeline-model-parallel-size 2
--expert-model-parallel-size 4 # MoE expert parallelism
```
### Speculative Decoding
Verified flags from miles documentation:
```bash
# Basic speculative decoding
--sglang-speculative-algorithm EAGLE
--sglang-speculative-num-steps 3
--sglang-speculative-eagle-topk 1
--sglang-speculative-num-draft-tokens 4
--sglang-enable-draft-weights-cpu-backup
# Draft model path
--sglang-speculative-draft-model-path /your/draft/model/path
# Online SFT for draft model (MTP)
--mtp-num-layers 1
--enable-mtp-training
--mtp-loss-scaling-factor 0.2
```
**Note**: Online MTP training requires a torch dist checkpoint with MTP weights. Add `--mtp-num-layers 1` during checkpoint conversion from HuggingFace to torch dist format.
## Key Features (Conceptual)
The following features are documented in miles but specific CLI flags are not publicly documented. Consult the miles repository for latest configuration options.
### Unified FP8 Pipeline
End-to-end FP8 sampling and training that eliminates quantization-induced discrepancy causing RL collapse in MoE models.
### Rollout Routing Replay (R3)
Records expert routing decisions during SGLang inference and replays them during Megatron training for bit-wise expert alignment.
**How R3 Works**:
1. During SGLang inference, expert routing decisions are recorded
2. Routing decisions stored in `sample.rollout_routed_experts`
3. During Megatron training, routing is replayed instead of recomputed
4. Ensures identical expert selection between train and inference
### INT4 Quantization-Aware Training
Enables single-machine deployment of 1TB+ models (e.g., on H200).
**Memory Savings with INT4**:
| Model Size | BF16 VRAM | INT4 VRAM | Reduction |
|------------|-----------|-----------|-----------|
| 70B | 140GB | 45GB | 3.1x |
| 235B | 470GB | 150GB | 3.1x |
| 671B | 1.3TB | 420GB | 3.1x |
### Train-Inference Alignment
miles achieves "exactly 0 KL divergence" between training and inference through infrastructure optimizations:
- Flash Attention 3
- DeepGEMM
- Batch-invariant kernels from Thinking Machines Lab
- `torch.compile` integration
### Truncated/Masked Importance Sampling (TIS/MIS)
Algorithmic corrections for off-policy training. See slime documentation for `--use-tis` flag.
## Custom Functions
Same interface as slime:
```bash
--custom-generate-function-path generate.py
--custom-rm-path reward.py
```
## Supported Models
| Family | Models | MoE Support |
|--------|--------|-------------|
| DeepSeek | R1, V3, V3.2 | Full |
| Qwen | 2, 2.5, 3 (including MoE) | Full |
| Llama | 3, 3.1, 3.3, 4 | Dense only |
| Gemma | 2, 3, 3N | Dense only |
| GLM | 4.5, 4.6, 4.7 | Dense only |
| MiniMax | M2, M2.1 | Full |
## Resources
- GitHub: https://github.com/radixark/miles
- Introduction Blog: https://lmsys.org/blog/2025-11-19-miles/
- Slime (upstream): https://github.com/THUDM/slime
- SGLang: https://github.com/sgl-project/sglang
@@ -0,0 +1,352 @@
# miles Troubleshooting Guide
## FP8 Training Issues
### Issue: FP8 Training Collapse
**Symptoms**: Loss explodes, NaN values, reward collapses
**Solutions**:
1. **Use block scaling**:
```bash
--fp8-recipe blockwise
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
```
2. **Enable R3 for MoE models**:
```bash
--use-r3
```
3. **Reduce learning rate**:
```bash
--lr 5e-7 # Reduce from 1e-6
```
4. **Warm up from BF16**:
```bash
--warmup-steps 100
--warmup-precision bf16
```
### Issue: FP8 vs BF16 Accuracy Gap
**Symptoms**: FP8 model underperforms BF16 baseline
**Solutions**:
1. **Use E4M3 format for activations**:
```bash
--fp8-format e4m3
```
2. **Enable dynamic scaling**:
```bash
--fp8-dynamic-scaling
```
3. **Skip sensitive layers**:
```bash
--fp8-skip-layers "lm_head,embed"
```
## Train-Inference Mismatch Issues
### Issue: Policy Divergence
**Symptoms**: Model behavior differs between training and inference
**Solutions**:
1. **Enable Rollout Routing Replay**:
```bash
--use-r3
```
2. **Use importance sampling correction**:
```bash
--use-tis --tis-threshold 0.9
```
3. **Verify log probs match**:
```bash
--verify-logprobs
```
### Issue: Expert Routing Mismatch (MoE)
**Symptoms**: Different experts activated during train vs inference
**Solutions**:
1. **Enable R3**:
```bash
--use-r3
--r3-buffer-size 1000
```
2. **Use deterministic routing**:
```bash
--deterministic-expert-routing
```
## INT4 Training Issues
### Issue: INT4 Accuracy Degradation
**Symptoms**: Worse performance than BF16 or FP8
**Solutions**:
1. **Increase group size**:
```bash
--int4-group-size 256 # Increase from 128
```
2. **Use mixed precision for sensitive layers**:
```bash
--int4-skip-layers "lm_head,embed,layer_norm"
```
3. **Warm start from BF16**:
```bash
--warmup-steps 100
--warmup-precision bf16
```
4. **Increase learning rate** (INT4 often needs higher LR):
```bash
--lr 2e-6 # Increase from 1e-6
```
### Issue: INT4 OOM Despite Expected Savings
**Symptoms**: Still running out of memory with INT4
**Solutions**:
1. **Verify environment variable**:
```bash
export OPEN_TRAINING_INT4_FAKE_QAT_FLAG=1
```
2. **Check group size alignment**:
```bash
# Group size must divide hidden dimension evenly
--int4-group-size 128 # Must divide hidden_size
```
## Speculative RL Issues
### Issue: Low Acceptance Rate
**Symptoms**: Draft model tokens frequently rejected
**Solutions**:
1. **Reduce lookahead**:
```bash
--spec-lookahead 3 # Reduce from 5
```
2. **Update draft more frequently**:
```bash
--online-sft-interval 5 # Reduce from 10
```
3. **Increase draft learning rate**:
```bash
--draft-lr 1e-5 # Increase
```
### Issue: Draft Model Drift
**Symptoms**: Acceptance rate drops over time
**Solutions**:
1. **Enable online SFT**:
```bash
--online-sft-interval 5
```
2. **Use EMA for draft updates**:
```bash
--draft-ema-decay 0.99
```
3. **Reinitialize draft periodically**:
```bash
--reinit-draft-interval 1000
```
### Issue: Speculative Training Slower Than Expected
**Symptoms**: Not achieving expected 25%+ speedup
**Solutions**:
1. **Verify draft model is small enough**:
```bash
# Draft should be 1/4 to 1/10 size of target
```
2. **Check lookahead is optimal**:
```bash
--spec-lookahead 5 # Sweet spot for most models
```
3. **Profile to find bottleneck**:
```bash
--profile-speculative
```
## Weight Synchronization Issues
### Issue: Zero-Copy Sync Failures
**Symptoms**: Errors with CUDA IPC, weight corruption
**Solutions**:
1. **Verify CUDA IPC support**:
```bash
nvidia-smi topo -m # Check GPU topology
```
2. **Fall back to standard sync**:
```bash
# Remove --use-zero-copy-sync
```
3. **Increase bucket size**:
```bash
--sync-bucket-size 2147483648 # 2GB
```
### Issue: Slow Weight Sync Despite Zero-Copy
**Symptoms**: Weight sync still slow
**Solutions**:
1. **Use colocated mode**:
```bash
--colocate
```
2. **Enable async weight transfer**:
```bash
--async-weight-sync
```
## MoE-Specific Issues
### Issue: Expert Load Imbalance
**Symptoms**: Some experts heavily loaded, others unused
**Solutions**:
1. **Enable load balancing loss**:
```bash
--aux-loss-coef 0.01
```
2. **Use capacity factor**:
```bash
--moe-capacity-factor 1.25
```
### Issue: Expert Parallelism OOM
**Symptoms**: OOM with large MoE models
**Solutions**:
1. **Increase expert parallelism**:
```bash
--expert-model-parallel-size 8 # Increase from 4
```
2. **Reduce batch size per GPU**:
```bash
--micro-batch-size 1
```
3. **Enable expert offloading**:
```bash
--offload-experts
```
## Multi-Agent Issues
### Issue: Co-Evolution Instability
**Symptoms**: Agents oscillate or one dominates
**Solutions**:
1. **Use alternating updates**:
```yaml
co_evolution:
strategy: alternating
```
2. **Reduce co-evolution frequency**:
```bash
--co-evolution-interval 20 # Increase from 10
```
3. **Add population diversity**:
```yaml
co_evolution:
population_size: 4
```
## Debugging Tips
### Enable Verbose Logging
```bash
--log-level DEBUG
export MILES_DEBUG=1
```
### Check FP8 Tensors
```python
# Verify FP8 is active
for name, param in model.named_parameters():
print(f"{name}: {param.dtype}")
```
### Profile Training
```bash
--profile
--profile-dir /path/to/profile
```
### Verify R3 Is Working
```python
# Check routing is being recorded
sample = samples[0]
assert sample.rollout_routed_experts is not None
assert len(sample.rollout_routed_experts) > 0
```
### Monitor GPU Memory
```bash
watch -n 1 nvidia-smi
```
## Resources
- GitHub Issues: https://github.com/radixark/miles/issues
- Unified FP8 Blog: https://lmsys.org/blog/2025-11-25-fp8-rl/
- Train-Inference Mismatch Tutorial: https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md
- SGLang Discord: Community support
@@ -0,0 +1,249 @@
---
name: openrlhf-training
description: High-performance RLHF framework with Ray+vLLM acceleration. Use for PPO, GRPO, RLOO, DPO training of large models (7B-70B+). Built on Ray, vLLM, ZeRO-3. 2× faster than DeepSpeedChat with distributed architecture and GPU resource sharing.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Post-Training, OpenRLHF, RLHF, PPO, GRPO, RLOO, DPO, Ray, vLLM, Distributed Training, Large Models, ZeRO-3]
dependencies: [openrlhf, ray, vllm, torch, transformers, deepspeed]
---
# OpenRLHF - High-Performance RLHF Training
## Quick start
OpenRLHF is a Ray-based RLHF framework optimized for distributed training with vLLM inference acceleration.
**Installation**:
```bash
# Launch Docker container
docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \
-v $PWD:/openrlhf nvcr.io/nvidia/pytorch:25.02-py3 bash
# Uninstall conflicts
sudo pip uninstall xgboost transformer_engine flash_attn pynvml -y
# Install OpenRLHF with vLLM
pip install openrlhf[vllm]
```
**PPO Training** (Hybrid Engine):
```bash
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8
ray job submit --address="http://127.0.0.1:8265" \
--runtime-env-json='{"working_dir": "/openrlhf"}' \
-- python3 -m openrlhf.cli.train_ppo_ray \
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
--critic_num_nodes 1 --critic_num_gpus_per_node 8 \
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
--colocate_all_models \
--vllm_gpu_memory_utilization 0.5 \
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
--reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
--save_path ./output/llama3-8b-rlhf \
--micro_train_batch_size 8 --train_batch_size 128 \
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
--max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 \
--zero_stage 3 --bf16 \
--actor_learning_rate 5e-7 --critic_learning_rate 9e-6 \
--init_kl_coef 0.01 --normalize_reward \
--gradient_checkpointing --packing_samples \
--vllm_enable_sleep --deepspeed_enable_sleep
```
**GRPO Training** (Group Relative Policy Optimization):
```bash
# Same command as PPO, but add:
--advantage_estimator group_norm
```
## Common workflows
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
**Step 1: Train reward model** (on chosen/rejected preference pairs):
```bash
deepspeed --module openrlhf.cli.train_rm \
--save_path ./output/llama3-8b-rm \
--save_steps -1 --logging_steps 1 \
--eval_steps -1 --train_batch_size 256 \
--micro_train_batch_size 1 --pretrain meta-llama/Meta-Llama-3-8B \
--bf16 --max_epochs 1 --max_len 8192 \
--zero_stage 3 --learning_rate 9e-6 \
--dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \
--apply_chat_template --chosen_key chosen \
--rejected_key rejected --flash_attn --gradient_checkpointing
```
**Step 2: PPO training**:
```bash
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8
ray job submit --address="http://127.0.0.1:8265" \
-- python3 -m openrlhf.cli.train_ppo_ray \
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
--critic_num_nodes 1 --critic_num_gpus_per_node 8 \
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
--colocate_all_models \
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
--reward_pretrain ./output/llama3-8b-rm \
--save_path ./output/llama3-8b-ppo \
--micro_train_batch_size 8 --train_batch_size 128 \
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
--max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 \
--zero_stage 3 --bf16 \
--actor_learning_rate 5e-7 --critic_learning_rate 9e-6 \
--init_kl_coef 0.01 --normalize_reward \
--vllm_enable_sleep --deepspeed_enable_sleep
```
### Workflow 2: GRPO training (no critic model needed)
Memory-efficient alternative to PPO:
```bash
ray job submit --address="http://127.0.0.1:8265" \
-- python3 -m openrlhf.cli.train_ppo_ray \
--advantage_estimator group_norm \
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
--colocate_all_models \
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
--reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
--save_path ./output/llama3-8b-grpo \
--micro_train_batch_size 8 --train_batch_size 128 \
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
--max_epochs 1 --bf16 \
--actor_learning_rate 5e-7 \
--init_kl_coef 0.01 --use_kl_loss --kl_estimator k3 \
--normalize_reward --no_advantage_std_norm
```
**Key GRPO parameters**:
- `--advantage_estimator group_norm` - Enables GRPO
- `--use_kl_loss` - KL loss from GRPO paper
- `--kl_estimator k3` - KL estimator used for the KL loss term (GRPO uses k3; when used as a loss, k2 is nearly equivalent to k1)
- `--no_advantage_std_norm` - Disables std normalization
### Workflow 3: DPO training (preference optimization)
Simpler alternative without reward model:
```bash
deepspeed --module openrlhf.cli.train_dpo \
--save_path ./output/llama3-8b-dpo \
--save_steps -1 --logging_steps 1 \
--eval_steps -1 --train_batch_size 256 \
--micro_train_batch_size 2 --pretrain meta-llama/Meta-Llama-3-8B \
--bf16 --max_epochs 1 --max_len 8192 \
--zero_stage 3 --learning_rate 5e-7 --beta 0.1 \
--dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \
--apply_chat_template --chosen_key chosen \
--rejected_key rejected --flash_attn --gradient_checkpointing
```
## When to use vs alternatives
**Use OpenRLHF when**:
- Training large models (7B-70B+) with RL
- Need vLLM inference acceleration
- Want distributed architecture with Ray
- Have multi-node GPU cluster
- Need PPO/GRPO/RLOO/DPO in one framework
**Algorithm selection**:
- **PPO**: Maximum control, best for complex rewards
- **GRPO**: Memory-efficient, no critic needed
- **RLOO**: Critic-free leave-one-out baseline, combined with per-token KL reward and PPO-clip loss
- **REINFORCE++**: More stable than GRPO, faster than PPO
- **DPO**: Simplest, no reward model needed
**Use alternatives instead**:
- **TRL**: Single-node training, simpler API
- **veRL**: ByteDance's framework for 671B models
- **DeepSpeedChat**: Integrated with DeepSpeed ecosystem
## Common issues
**Issue: GPU OOM with large models**
Disable model colocation:
```bash
# Remove --colocate_all_models flag
# Allocate separate GPUs for each model
--actor_num_gpus_per_node 8 \
--critic_num_gpus_per_node 8 \
--reward_num_gpus_per_node 8 \
--ref_num_gpus_per_node 8
```
**Issue: DeepSpeed GPU index out of range**
Set environment variable:
```bash
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
```
**Issue: Training instability**
Use Hybrid Engine instead of async:
```bash
--colocate_all_models \
--vllm_enable_sleep \
--deepspeed_enable_sleep
```
Adjust KL coefficient:
```bash
--init_kl_coef 0.05 # Increase from 0.01
```
**Issue: Slow generation during PPO**
Enable vLLM acceleration:
```bash
--vllm_num_engines 4 \
--vllm_tensor_parallel_size 2 \
--vllm_gpu_memory_utilization 0.5
```
## Advanced topics
**Hybrid Engine GPU sharing**: See [references/hybrid-engine.md](references/hybrid-engine.md) for vLLM sleep mode, DeepSpeed sleep mode, and optimal node allocation.
**Algorithm comparison**: See [references/algorithm-comparison.md](references/algorithm-comparison.md) for PPO vs GRPO vs RLOO vs REINFORCE++ benchmarks and hyperparameters.
**Multi-node setup**: See [references/multi-node-training.md](references/multi-node-training.md) for Ray cluster configuration and fault tolerance.
**Custom reward functions**: See [references/custom-rewards.md](references/custom-rewards.md) for reinforced fine-tuning and agent RLHF.
## Hardware requirements
- **GPU**: NVIDIA A100/H100 recommended
- **VRAM**:
- 7B model: 8× A100 40GB (Hybrid Engine)
- 70B model: 48× A100 80GB (vLLM:Actor:Critic = 1:1:1)
- **Multi-node**: Ray cluster with InfiniBand recommended
- **Docker**: NVIDIA PyTorch container 25.02+
**Performance**:
- 2× faster than DeepSpeedChat
- vLLM inference acceleration
- Hybrid Engine minimizes GPU idle time
## Resources
- Docs: https://github.com/OpenRLHF/OpenRLHF
- Paper: https://arxiv.org/abs/2405.11143
- Examples: https://github.com/OpenRLHF/OpenRLHF/tree/main/examples
- Discord: Community support
@@ -0,0 +1,404 @@
# Algorithm Comparison
Complete guide to RL algorithms in OpenRLHF: PPO, REINFORCE++, GRPO, RLOO, and their variants.
## Overview
OpenRLHF supports 6 RL algorithms selectable via `--advantage_estimator`:
- **gae** - PPO with Generalized Advantage Estimation
- **reinforce** - REINFORCE++ (PPO optimizations without critic)
- **reinforce_baseline** - REINFORCE++ with baseline
- **group_norm** - GRPO (Group Relative Policy Optimization)
- **dr_grpo** - Dr. GRPO (GRPO without std normalization)
- **rloo** - RLOO (REINFORCE Leave-One-Out)
## Algorithm Details
### PPO (Proximal Policy Optimization)
**Formula**:
```
loss = -min(ratio * advantages, clip(ratio, 1-ε, 1+ε) * advantages)
ratio = π_new(a|s) / π_old(a|s)
```
**Characteristics**:
- **Stability**: High (clipped objective prevents large updates)
- **Memory**: High (stores actor + critic experiences)
- **Speed**: Medium (critic training overhead)
- **Requires**: Critic network for value estimation
**Implementation**:
```python
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
loss = -torch.min(surr1, surr2)
```
**When to use**:
- General-purpose RLHF
- Complex reward functions
- Need stable training
**Hyperparameters**:
```bash
--advantage_estimator gae # Enable PPO
--clip_eps_low 0.2 # Clipping lower bound
--clip_eps_high 0.2 # Clipping upper bound
--actor_learning_rate 1e-6
--critic_learning_rate 9e-6
--init_kl_coef 0.01
```
### REINFORCE++
**Formula**:
```
loss = -ratio * advantages (with PPO-clip)
advantages = cumulative_returns - baseline
```
**Characteristics**:
- **Stability**: Higher than GRPO
- **Memory**: Lower (no critic network)
- **Speed**: Faster than PPO
- **Requires**: No critic network
**Key innovation**: Integrates PPO optimizations (advantage normalization, PPO-clip loss) into REINFORCE while eliminating critic network overhead.
**When to use**:
- Want PPO stability without critic
- Limited memory budget
- Fast training priority
**Hyperparameters**:
```bash
--advantage_estimator reinforce
--critic_pretrain None # No critic needed
--init_kl_coef 0.01
--actor_learning_rate 1e-6
```
### REINFORCE++-baseline
**Formula**:
```
rewards = rewards - mean(rewards_same_prompt)
```
**Characteristics**:
- **Stability**: Very high
- **Memory**: Lower (no critic)
- **Speed**: Faster than PPO
- **Requires**: Multiple samples per prompt
**Key innovation**: Uses mean reward of multiple samples from same prompt as baseline to reshape rewards.
**When to use**:
- RLVR (Reinforcement Learning with Verifiable Rewards) settings
- Reward patterns vary (0/1/-0.5)
- Multiple samples per prompt available
**Hyperparameters**:
```bash
--advantage_estimator reinforce_baseline
--n_samples_per_prompt 4 # Must be > 1
--init_kl_coef 0.01
```
### GRPO (Group Relative Policy Optimization)
**Formula**:
```
rewards = (rewards - mean(rewards)) / (std(rewards) + 1e-9)
loss = -ratio * normalized_advantages
KL loss (optional): k1, k2, or k3 estimator
```
**Characteristics**:
- **Stability**: Lower than REINFORCE++
- **Memory**: Lower (no critic)
- **Speed**: Fast
- **Requires**: Group reward normalization
**Key innovation**: Group-based advantage normalization with optional KL loss.
**When to use**:
- Exploring policy optimization variants
- Need reward normalization
- Memory-constrained
**Hyperparameters**:
```bash
--advantage_estimator group_norm
--use_kl_loss # Enable KL loss
--kl_estimator k3 # k3 for loss, k2 ≈ k1
--init_kl_coef 0.01
--no_advantage_std_norm # Optional: disable std norm
```
**KL estimator variance**:
- **k3**: Larger variance under categorical distribution
- **k1, k2**: Similar variance; k2 used as a loss is nearly equivalent to k1
### Dr. GRPO
**Formula**:
```
rewards = rewards - mean(rewards) # No std normalization
```
**Characteristics**:
- **Stability**: Similar to GRPO
- **Memory**: Lower (no critic)
- **Speed**: Fast
- **Requires**: Group mean normalization only
**Key innovation**: Removes local group normalization `/std` from GRPO (not needed in RL variance reduction theory).
**When to use**:
- GRPO variant experimentation
- Avoid std normalization issues
**Hyperparameters**:
```bash
--advantage_estimator dr_grpo
--init_kl_coef 0.01
```
### RLOO (REINFORCE Leave-One-Out)
**Formula**:
```
baseline = (sum(rewards) - rewards) / (n_samples - 1)
rewards = rewards - baseline
loss = -ratio * advantages (with PPO-clip)
```
**Characteristics**:
- **Stability**: High (PPO-clip)
- **Memory**: Lower (no critic)
- **Speed**: Fast
- **Requires**: Multiple samples per prompt, per-token KL
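**Key innovation**: Uses the mean reward of the other samples for the same prompt as a leave-one-out baseline; OpenRLHF's variant additionally incorporates a per-token KL reward and the PPO-clip loss.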
**When to use**:
- Need per-token KL rewards
- Want PPO stability without critic
- Multiple samples per prompt
**Hyperparameters**:
```bash
--advantage_estimator rloo
--n_samples_per_prompt 4 # Must be > 1
--init_kl_coef 0.01
```
## Comparison Table
| Algorithm | Critic | Stability | Memory | Speed | Best For |
|-----------|--------|-----------|--------|-------|----------|
| PPO | ✅ Yes | ⭐⭐⭐⭐⭐ | High | Medium | General purpose |
| REINFORCE++ | ❌ No | ⭐⭐⭐⭐ | Low | **Fast** | Critic-free PPO |
| REINFORCE++-baseline | ❌ No | ⭐⭐⭐⭐⭐ | Low | **Fast** | RLVR settings |
| GRPO | ❌ No | ⭐⭐⭐ | Low | Fast | Reward normalization |
| Dr. GRPO | ❌ No | ⭐⭐⭐ | Low | Fast | GRPO variant |
| RLOO | ❌ No | ⭐⭐⭐⭐ | Low | Fast | Per-token KL |
## Experience Data Structure
**PPO (with critic)**:
```python
@dataclass
class Experience:
sequences: torch.Tensor # Token sequences
attention_mask: torch.Tensor # Attention masks
action_mask: torch.Tensor # Action masks
action_log_probs: torch.Tensor # Log π(a|s)
values: torch.Tensor # Critic value estimates
returns: torch.Tensor # Cumulative returns
advantages: torch.Tensor # GAE advantages
reward: float # Total reward
kl: torch.Tensor # KL divergence
```
**REINFORCE++ (no critic)**:
```python
# No values, returns, or advantages stored
# Only sequences, log_probs, and rewards
```
## Memory Comparison (7B Model)
| Algorithm | Components | Memory (8× A100) |
|-----------|-----------|------------------|
| PPO | Actor + Critic + Reward + Ref | ~40GB |
| REINFORCE++ | Actor + Reward + Ref | ~28GB |
| GRPO | Actor + Reward + Ref | ~28GB |
| RLOO | Actor + Reward + Ref | ~28GB |
**Savings**: ~30% memory reduction without critic
## Speed Comparison
**Relative training time** (7B model, 1000 steps):
- PPO: 1.0× baseline
- REINFORCE++: **0.75×** (25% faster)
- GRPO: 0.80×
- RLOO: 0.80×
**Why REINFORCE++ is faster**:
- No critic training
- No value function updates
- Fewer backward passes
## Choosing an Algorithm
### Decision Tree
```
Need maximum stability?
├─ Yes → PPO (with critic)
└─ No ↓
Have multiple samples per prompt?
├─ Yes ↓
│ └─ RLVR setting with varying rewards?
│ ├─ Yes → REINFORCE++-baseline
│ └─ No → RLOO (if need per-token KL)
└─ No ↓
Want faster than PPO?
└─ Yes → REINFORCE++ (most stable critic-free)
Experimenting with normalization?
└─ Yes → GRPO or Dr. GRPO
```
### By Use Case
**Production deployment**:
```bash
# Maximum stability
--advantage_estimator gae # PPO
--clip_eps_low 0.2
--init_kl_coef 0.01
```
**Memory-constrained**:
```bash
# No critic, stable
--advantage_estimator reinforce # REINFORCE++
--critic_pretrain None
```
**RLVR / Verification rewards**:
```bash
# Baseline reward shaping
--advantage_estimator reinforce_baseline
--n_samples_per_prompt 4
```
**Research / Experimentation**:
```bash
# Explore GRPO variants
--advantage_estimator group_norm
--use_kl_loss --kl_estimator k3
```
## Advanced Configuration
### Reward Normalization
**PPO (no manual normalization)**:
```bash
--advantage_estimator gae
# GAE handles advantage normalization
```
**GRPO (group normalization)**:
```bash
--advantage_estimator group_norm
--normalize_reward # Optional additional normalization
```
**Disable std normalization**:
```bash
--no_advantage_std_norm # Keep mean norm only
```
### KL Penalty Configuration
**All algorithms support**:
```bash
--init_kl_coef 0.01 # Initial KL coefficient
--kl_target 0.1 # Target KL divergence
--kl_horizon 10000 # Steps to reach target
```
**GRPO-specific**:
```bash
--use_kl_loss # Enable KL loss term
--kl_estimator k3 # Loss function choice
```
### Clipping Configuration
**PPO clipping**:
```bash
--clip_eps_low 0.2 # Lower bound
--clip_eps_high 0.2 # Upper bound
```
**Reward clipping**:
```bash
--reward_clip_range 10.0 # Clip rewards to [-10, 10]
```
## Common Issues
### PPO Instability
**Symptom**: Large policy updates, divergence
**Solution**: Reduce clipping range
```bash
--clip_eps_low 0.1 # Reduce from 0.2
--clip_eps_high 0.1
```
### GRPO High Variance
**Symptom**: Unstable training with GRPO
**Solution**: Switch to REINFORCE++
```bash
--advantage_estimator reinforce # More stable
```
### Memory OOM with PPO
**Symptom**: OOM during critic training
**Solution**: Switch to critic-free
```bash
--advantage_estimator reinforce # No critic
--critic_pretrain None
```
### RLOO/Baseline Requires Multiple Samples
**Symptom**: `AssertionError: n_samples_per_prompt must be > 1`
**Solution**:
```bash
--n_samples_per_prompt 4 # Minimum 2, recommended 4-8
```
## References
- PPO paper: https://arxiv.org/abs/1707.06347
- GRPO paper: https://arxiv.org/abs/2402.03300
- OpenRLHF: https://github.com/OpenRLHF/OpenRLHF
- OpenRLHF paper: https://arxiv.org/abs/2405.11143

Some files were not shown because too many files have changed in this diff Show More