fix: dereference orchestra-skills submodule, add as plain files
This commit is contained in:
BIN
Binary file not shown.
BIN
Binary file not shown.
Submodule orchestra-skills deleted from 28f2d29236
@@ -0,0 +1,293 @@
|
||||
{
|
||||
"name": "ai-research-skills",
|
||||
"owner": {
|
||||
"name": "Orchestra Research",
|
||||
"email": "zechen@orchestra-research.com"
|
||||
},
|
||||
"metadata": {
|
||||
"description": "Comprehensive library of 98 AI research engineering skills enabling autonomous AI research from hypothesis to experimental verification",
|
||||
"version": "1.2.0"
|
||||
},
|
||||
"plugins": [
|
||||
{
|
||||
"name": "model-architecture",
|
||||
"description": "LLM architectures and implementations including LitGPT, Mamba, NanoGPT, RWKV, and TorchTitan. Use when implementing, training, or understanding transformer and alternative architectures.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./01-model-architecture/litgpt",
|
||||
"./01-model-architecture/mamba",
|
||||
"./01-model-architecture/nanogpt",
|
||||
"./01-model-architecture/rwkv",
|
||||
"./01-model-architecture/torchtitan"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "tokenization",
|
||||
"description": "Text tokenization for LLMs including HuggingFace Tokenizers and SentencePiece. Use when training custom tokenizers or handling multilingual text.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./02-tokenization/huggingface-tokenizers",
|
||||
"./02-tokenization/sentencepiece"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "fine-tuning",
|
||||
"description": "LLM fine-tuning frameworks including Axolotl, LLaMA-Factory, PEFT, and Unsloth. Use when fine-tuning models with LoRA, QLoRA, or full fine-tuning.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./03-fine-tuning/axolotl",
|
||||
"./03-fine-tuning/llama-factory",
|
||||
"./03-fine-tuning/peft",
|
||||
"./03-fine-tuning/unsloth"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "mechanistic-interpretability",
|
||||
"description": "Neural network interpretability tools including TransformerLens, SAELens, NNSight, and pyvene. Use when analyzing model internals, finding circuits, or understanding how models compute.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./04-mechanistic-interpretability/nnsight",
|
||||
"./04-mechanistic-interpretability/pyvene",
|
||||
"./04-mechanistic-interpretability/saelens",
|
||||
"./04-mechanistic-interpretability/transformer-lens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "data-processing",
|
||||
"description": "Data curation and processing at scale including NeMo Curator and Ray Data. Use when preparing training datasets or processing large-scale data.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./05-data-processing/nemo-curator",
|
||||
"./05-data-processing/ray-data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "post-training",
|
||||
"description": "RLHF and preference alignment including TRL, GRPO, OpenRLHF, SimPO, verl, slime, miles, and torchforge. Use when aligning models with human preferences, training reward models, or large-scale RL training.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./06-post-training/grpo-rl-training",
|
||||
"./06-post-training/miles",
|
||||
"./06-post-training/openrlhf",
|
||||
"./06-post-training/simpo",
|
||||
"./06-post-training/slime",
|
||||
"./06-post-training/torchforge",
|
||||
"./06-post-training/trl-fine-tuning",
|
||||
"./06-post-training/verl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "safety-alignment",
|
||||
"description": "AI safety and content moderation including Constitutional AI, LlamaGuard, NeMo Guardrails, and Prompt Guard. Use when implementing safety filters, content moderation, or prompt injection detection.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./07-safety-alignment/constitutional-ai",
|
||||
"./07-safety-alignment/llamaguard",
|
||||
"./07-safety-alignment/nemo-guardrails",
|
||||
"./07-safety-alignment/prompt-guard"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "distributed-training",
|
||||
"description": "Multi-GPU and multi-node training including DeepSpeed, PyTorch FSDP, Accelerate, Megatron-Core, PyTorch Lightning, and Ray Train. Use when training large models across GPUs.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./08-distributed-training/accelerate",
|
||||
"./08-distributed-training/deepspeed",
|
||||
"./08-distributed-training/megatron-core",
|
||||
"./08-distributed-training/pytorch-fsdp2",
|
||||
"./08-distributed-training/pytorch-lightning",
|
||||
"./08-distributed-training/ray-train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "infrastructure",
|
||||
"description": "GPU cloud and compute orchestration including Modal, Lambda Labs, and SkyPilot. Use when deploying training jobs or managing GPU resources.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./09-infrastructure/lambda-labs",
|
||||
"./09-infrastructure/modal",
|
||||
"./09-infrastructure/skypilot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "optimization",
|
||||
"description": "Model optimization and quantization including Flash Attention, bitsandbytes, GPTQ, AWQ, GGUF, and HQQ. Use when reducing memory, accelerating inference, or quantizing models.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./10-optimization/awq",
|
||||
"./10-optimization/bitsandbytes",
|
||||
"./10-optimization/flash-attention",
|
||||
"./10-optimization/gguf",
|
||||
"./10-optimization/gptq",
|
||||
"./10-optimization/hqq",
|
||||
"./10-optimization/ml-training-recipes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "evaluation",
|
||||
"description": "LLM benchmarking and evaluation including lm-evaluation-harness, BigCode Evaluation Harness, and NeMo Evaluator. Use when benchmarking models or measuring performance.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./11-evaluation/bigcode-evaluation-harness",
|
||||
"./11-evaluation/lm-evaluation-harness",
|
||||
"./11-evaluation/nemo-evaluator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "inference-serving",
|
||||
"description": "Production LLM inference including vLLM, TensorRT-LLM, llama.cpp, and SGLang. Use when deploying models for production inference.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./12-inference-serving/llama-cpp",
|
||||
"./12-inference-serving/sglang",
|
||||
"./12-inference-serving/tensorrt-llm",
|
||||
"./12-inference-serving/vllm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "mlops",
|
||||
"description": "ML experiment tracking and lifecycle including Weights & Biases, MLflow, and TensorBoard. Use when tracking experiments or managing models.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./13-mlops/mlflow",
|
||||
"./13-mlops/tensorboard",
|
||||
"./13-mlops/weights-and-biases"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "agents",
|
||||
"description": "LLM agent frameworks including LangChain, LlamaIndex, CrewAI, and AutoGPT. Use when building chatbots, autonomous agents, or tool-using systems.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./14-agents/autogpt",
|
||||
"./14-agents/crewai",
|
||||
"./14-agents/langchain",
|
||||
"./14-agents/llamaindex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "rag",
|
||||
"description": "Retrieval-Augmented Generation including Chroma, FAISS, Pinecone, Qdrant, and Sentence Transformers. Use when building semantic search or document retrieval systems.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./15-rag/chroma",
|
||||
"./15-rag/faiss",
|
||||
"./15-rag/pinecone",
|
||||
"./15-rag/qdrant",
|
||||
"./15-rag/sentence-transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "prompt-engineering",
|
||||
"description": "Structured LLM outputs including DSPy, Instructor, Guidance, and Outlines. Use when extracting structured data or constraining LLM outputs.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./16-prompt-engineering/dspy",
|
||||
"./16-prompt-engineering/guidance",
|
||||
"./16-prompt-engineering/instructor",
|
||||
"./16-prompt-engineering/outlines"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "observability",
|
||||
"description": "LLM application monitoring including LangSmith and Phoenix. Use when debugging LLM apps or monitoring production systems.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./17-observability/langsmith",
|
||||
"./17-observability/phoenix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "multimodal",
|
||||
"description": "Vision, audio, and multimodal models including CLIP, Whisper, LLaVA, BLIP-2, Segment Anything, Stable Diffusion, AudioCraft, Cosmos Policy, OpenPI, and OpenVLA-OFT. Use when working with images, audio, multimodal tasks, or vision-language-action robot policies.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./18-multimodal/audiocraft",
|
||||
"./18-multimodal/blip-2",
|
||||
"./18-multimodal/clip",
|
||||
"./18-multimodal/cosmos-policy",
|
||||
"./18-multimodal/llava",
|
||||
"./18-multimodal/openpi",
|
||||
"./18-multimodal/openvla-oft",
|
||||
"./18-multimodal/segment-anything",
|
||||
"./18-multimodal/stable-diffusion",
|
||||
"./18-multimodal/whisper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "emerging-techniques",
|
||||
"description": "Advanced ML techniques including MoE Training, Model Merging, Long Context, Speculative Decoding, Knowledge Distillation, and Model Pruning. Use when implementing cutting-edge optimization or architecture techniques.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./19-emerging-techniques/knowledge-distillation",
|
||||
"./19-emerging-techniques/long-context",
|
||||
"./19-emerging-techniques/model-merging",
|
||||
"./19-emerging-techniques/model-pruning",
|
||||
"./19-emerging-techniques/moe-training",
|
||||
"./19-emerging-techniques/speculative-decoding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "autoresearch",
|
||||
"description": "Autonomous research orchestration using a two-loop architecture. Manages the full research lifecycle from literature survey to paper writing, routing to domain-specific skills for execution. Use when starting a research project, running autonomous experiments, or managing multi-hypothesis research.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./0-autoresearch-skill"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ml-paper-writing",
|
||||
"description": "Write publication-ready ML/AI/Systems papers for NeurIPS, ICML, ICLR, ACL, AAAI, COLM, OSDI, NSDI, ASPLOS, SOSP. Includes LaTeX templates, citation verification, reviewer guidelines, publication-quality figure generation, systems paper structural blueprints, and conference presentation slides.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./20-ml-paper-writing/ml-paper-writing",
|
||||
"./20-ml-paper-writing/academic-plotting",
|
||||
"./20-ml-paper-writing/systems-paper-writing",
|
||||
"./20-ml-paper-writing/presenting-conference-talks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ideation",
|
||||
"description": "Research ideation frameworks including structured brainstorming and creative thinking. Use when exploring new research directions, generating novel ideas, or seeking fresh angles on existing work.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./21-research-ideation/brainstorming-research-ideas",
|
||||
"./21-research-ideation/creative-thinking-for-research"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "agent-native-research-artifact",
|
||||
"description": "Agent-Native Research Artifact (ARA) tooling: compile any research input (paper, repo, notes) into a structured artifact, record session provenance as a post-task epilogue, and run Seal Level 2 epistemic review. Use when ingesting research into a falsifiable, agent-traversable artifact, capturing how a research project actually evolved, or auditing an ARA for evidence-claim alignment.",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./22-agent-native-research-artifact/compiler",
|
||||
"./22-agent-native-research-artifact/research-manager",
|
||||
"./22-agent-native-research-artifact/rigor-reviewer"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
name: Claude Code
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
issues:
|
||||
types: [opened, assigned]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
claude:
|
||||
if: |
|
||||
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude') && contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association)) ||
|
||||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude') && contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association)) ||
|
||||
(github.event_name == 'issues' && contains(github.event.issue.body, '@claude') && contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.issue.author_association))
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,73 @@
|
||||
name: Publish to npm
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'packages/ai-research-skills/**'
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: packages/ai-research-skills
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Check if version changed
|
||||
id: version
|
||||
run: |
|
||||
CURRENT=$(node -p "require('./package.json').version")
|
||||
PREVIOUS=$(git show HEAD~1:packages/ai-research-skills/package.json 2>/dev/null | node -p "JSON.parse(require('fs').readFileSync('/dev/stdin','utf8')).version" 2>/dev/null || echo "")
|
||||
echo "current=$CURRENT"
|
||||
echo "previous=$PREVIOUS"
|
||||
if [ "$CURRENT" != "$PREVIOUS" ]; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
echo "version=$CURRENT" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Check if version already published
|
||||
if: steps.version.outputs.changed == 'true'
|
||||
id: published
|
||||
run: |
|
||||
VERSION=${{ steps.version.outputs.version }}
|
||||
if npm view @orchestra-research/ai-research-skills@$VERSION version 2>/dev/null; then
|
||||
echo "already_published=true" >> $GITHUB_OUTPUT
|
||||
echo "Version $VERSION already on npm, skipping"
|
||||
else
|
||||
echo "already_published=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Setup Node.js
|
||||
if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
|
||||
run: npm ci
|
||||
|
||||
- name: Publish to npm
|
||||
if: steps.version.outputs.changed == 'true' && steps.published.outputs.already_published == 'false'
|
||||
run: |
|
||||
echo "Publishing v${{ steps.version.outputs.version }} to npm..."
|
||||
unset NODE_AUTH_TOKEN
|
||||
npm config delete //registry.npmjs.org/:_authToken || true
|
||||
npm publish --access public --provenance
|
||||
|
||||
- name: Skip reason
|
||||
if: steps.version.outputs.changed != 'true'
|
||||
run: echo "Version unchanged, skipping publish"
|
||||
+199
@@ -0,0 +1,199 @@
|
||||
name: Sync Skills to Orchestra
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
jobs:
|
||||
sync-skills:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # Fetch last 2 commits to detect changes
|
||||
|
||||
- name: Detect changed skill folders
|
||||
id: changes
|
||||
run: |
|
||||
# Get list of changed files in last commit
|
||||
CHANGED_FILES=$(git diff --name-only HEAD^..HEAD)
|
||||
|
||||
echo "Changed files:"
|
||||
echo "$CHANGED_FILES"
|
||||
|
||||
# Find skill directories - supports two patterns:
|
||||
# Pattern 1: XX-category/skill-name/SKILL.md (nested skills)
|
||||
# Pattern 2: XX-category/SKILL.md (standalone skills like 20-ml-paper-writing)
|
||||
|
||||
SKILL_DIRS=""
|
||||
|
||||
# Pattern 1: Nested skills (XX-category/skill-name/)
|
||||
NESTED=$(echo "$CHANGED_FILES" | grep -E '^[0-9]{2}-[^/]+/[^/]+/' | sed -E 's|^([0-9]{2}-[^/]+/[^/]+)/.*|\1|' | sort -u)
|
||||
if [ -n "$NESTED" ]; then
|
||||
SKILL_DIRS="$NESTED"
|
||||
fi
|
||||
|
||||
# Pattern 2: Standalone skills (XX-category/ with SKILL.md directly inside)
|
||||
STANDALONE=$(echo "$CHANGED_FILES" | grep -E '^[0-9]{2}-[^/]+/SKILL\.md$' | sed -E 's|^([0-9]{2}-[^/]+)/SKILL\.md$|\1|' | sort -u)
|
||||
if [ -n "$STANDALONE" ]; then
|
||||
if [ -n "$SKILL_DIRS" ]; then
|
||||
SKILL_DIRS=$(printf "%s\n%s" "$SKILL_DIRS" "$STANDALONE" | sort -u)
|
||||
else
|
||||
SKILL_DIRS="$STANDALONE"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Changed skill directories:"
|
||||
echo "$SKILL_DIRS"
|
||||
|
||||
# Convert to JSON array for matrix
|
||||
if [ -z "$SKILL_DIRS" ]; then
|
||||
SKILLS_JSON="[]"
|
||||
SKILL_COUNT=0
|
||||
else
|
||||
SKILLS_JSON=$(echo "$SKILL_DIRS" | jq -R -s -c 'split("\n") | map(select(length > 0))')
|
||||
SKILL_COUNT=$(echo "$SKILL_DIRS" | grep -c . || echo "0")
|
||||
fi
|
||||
|
||||
echo "skills=$SKILLS_JSON" >> $GITHUB_OUTPUT
|
||||
echo "count=$SKILL_COUNT" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Process and sync skills
|
||||
if: steps.changes.outputs.count > 0
|
||||
env:
|
||||
ORCHESTRA_API_URL: ${{ secrets.ORCHESTRA_API_URL }}
|
||||
ORCHESTRA_SYNC_API_KEY: ${{ secrets.ORCHESTRA_SYNC_API_KEY }}
|
||||
run: |
|
||||
SKILLS='${{ steps.changes.outputs.skills }}'
|
||||
|
||||
echo "Processing $(echo $SKILLS | jq 'length') skill(s)..."
|
||||
|
||||
# Install jq for JSON processing
|
||||
sudo apt-get update && sudo apt-get install -y jq zip
|
||||
|
||||
# Loop through each skill directory
|
||||
echo "$SKILLS" | jq -r '.[]' | while read SKILL_PATH; do
|
||||
echo "==================================================="
|
||||
echo "Processing: $SKILL_PATH"
|
||||
echo "==================================================="
|
||||
|
||||
# Check if SKILL.md exists
|
||||
if [ ! -f "$SKILL_PATH/SKILL.md" ]; then
|
||||
echo "⚠️ WARNING: No SKILL.md found in $SKILL_PATH, skipping"
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract skill name from SKILL.md frontmatter
|
||||
SKILL_NAME=$(grep -A 20 "^---$" "$SKILL_PATH/SKILL.md" | grep "^name:" | head -1 | sed 's/name: *//;s/"//g;s/'\''//g' | tr -d '\r')
|
||||
|
||||
# Extract author from SKILL.md frontmatter
|
||||
AUTHOR=$(grep -A 20 "^---$" "$SKILL_PATH/SKILL.md" | grep "^author:" | head -1 | sed 's/author: *//;s/"//g;s/'\''//g' | tr -d '\r')
|
||||
|
||||
# Default values
|
||||
if [ -z "$SKILL_NAME" ]; then
|
||||
# Extract from directory name as fallback
|
||||
SKILL_NAME=$(basename "$SKILL_PATH")
|
||||
echo "⚠️ No 'name' in frontmatter, using directory name: $SKILL_NAME"
|
||||
fi
|
||||
|
||||
if [ -z "$AUTHOR" ]; then
|
||||
AUTHOR="Orchestra Research"
|
||||
echo "⚠️ No 'author' in frontmatter, defaulting to: $AUTHOR"
|
||||
fi
|
||||
|
||||
echo "Skill Name: $SKILL_NAME"
|
||||
echo "Author: $AUTHOR"
|
||||
echo "Path: $SKILL_PATH"
|
||||
|
||||
# Create temporary directory for zipping
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
SKILL_DIR="$TEMP_DIR/$SKILL_NAME"
|
||||
mkdir -p "$SKILL_DIR"
|
||||
|
||||
# Copy all contents of skill directory (SKILL.md, references/, scripts/, assets/, etc.)
|
||||
cp -r "$SKILL_PATH"/* "$SKILL_DIR/" 2>/dev/null || true
|
||||
|
||||
# Create zip file (exclude hidden files and .gitkeep)
|
||||
ZIP_FILE="$TEMP_DIR/${SKILL_NAME}.zip"
|
||||
cd "$TEMP_DIR"
|
||||
zip -r "$ZIP_FILE" "$SKILL_NAME" -x "*/.*" "*/.gitkeep" "*.DS_Store"
|
||||
cd -
|
||||
|
||||
# Verify zip was created
|
||||
if [ ! -f "$ZIP_FILE" ]; then
|
||||
echo "❌ ERROR: Failed to create zip file for $SKILL_NAME"
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "✓ Created zip: $(ls -lh "$ZIP_FILE" | awk '{print $5}')"
|
||||
|
||||
# Write SKILL.md content to temp file (avoid argument length limits)
|
||||
SKILL_MD_FILE="$TEMP_DIR/skill.md"
|
||||
cat "$SKILL_PATH/SKILL.md" > "$SKILL_MD_FILE"
|
||||
|
||||
# Encode zip to base64 and write to temp file (avoid argument length limits)
|
||||
ZIP_BASE64_FILE="$TEMP_DIR/base64.txt"
|
||||
base64 -w 0 "$ZIP_FILE" > "$ZIP_BASE64_FILE" 2>/dev/null || base64 "$ZIP_FILE" > "$ZIP_BASE64_FILE"
|
||||
|
||||
# Prepare JSON payload (use --rawfile for large content)
|
||||
JSON_PAYLOAD=$(jq -n \
|
||||
--arg skillName "$SKILL_NAME" \
|
||||
--arg skillPath "$SKILL_PATH" \
|
||||
--arg author "$AUTHOR" \
|
||||
--rawfile skillMdContent "$SKILL_MD_FILE" \
|
||||
--rawfile zipBase64 "$ZIP_BASE64_FILE" \
|
||||
'{
|
||||
skillName: $skillName,
|
||||
skillPath: $skillPath,
|
||||
author: $author,
|
||||
skillMdContent: $skillMdContent,
|
||||
zipBase64: $zipBase64
|
||||
}')
|
||||
|
||||
# Send to Orchestra API (write JSON to file to avoid argument length limits)
|
||||
echo "📤 Uploading to Orchestra..."
|
||||
JSON_FILE="$TEMP_DIR/payload.json"
|
||||
echo "$JSON_PAYLOAD" > "$JSON_FILE"
|
||||
|
||||
RESPONSE=$(curl -s -w "\n%{http_code}" -L \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-Admin-API-Key: $ORCHESTRA_SYNC_API_KEY" \
|
||||
-d @"$JSON_FILE" \
|
||||
"$ORCHESTRA_API_URL/api/admin/sync-github-skill")
|
||||
|
||||
HTTP_CODE=$(echo "$RESPONSE" | tail -n1)
|
||||
BODY=$(echo "$RESPONSE" | sed '$d')
|
||||
|
||||
echo "HTTP Status: $HTTP_CODE"
|
||||
echo "Response: $BODY"
|
||||
|
||||
if [ "$HTTP_CODE" = "200" ]; then
|
||||
ACTION=$(echo "$BODY" | jq -r '.action // "synced"')
|
||||
SOURCE=$(echo "$BODY" | jq -r '.source // "unknown"')
|
||||
echo "✅ SUCCESS: Skill $SKILL_NAME $ACTION (source: $SOURCE)"
|
||||
else
|
||||
ERROR_MSG=$(echo "$BODY" | jq -r '.error // "Unknown error"')
|
||||
echo "❌ FAILED: $ERROR_MSG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
rm -rf "$TEMP_DIR"
|
||||
|
||||
echo ""
|
||||
done
|
||||
|
||||
echo "==================================================="
|
||||
echo "✅ Sync completed successfully!"
|
||||
echo "==================================================="
|
||||
|
||||
- name: No changes detected
|
||||
if: steps.changes.outputs.count == 0
|
||||
run: |
|
||||
echo "ℹ️ No skill changes detected in this commit"
|
||||
echo "Only commits that modify skill directories will trigger sync"
|
||||
@@ -0,0 +1,103 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
|
||||
# LaTeX auxiliary files
|
||||
*.aux
|
||||
*.bbl
|
||||
*.blg
|
||||
*.out
|
||||
*.fls
|
||||
*.fdb_latexmk
|
||||
*.synctex.gz
|
||||
*.toc
|
||||
*.lof
|
||||
*.lot
|
||||
*.nav
|
||||
*.snm
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
*.manifest
|
||||
*.spec
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
*.ipynb
|
||||
|
||||
# Pytest
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# ML/Data
|
||||
*.h5
|
||||
*.pkl
|
||||
*.pth
|
||||
*.ckpt
|
||||
*.safetensors
|
||||
wandb/
|
||||
runs/
|
||||
outputs/
|
||||
checkpoints/
|
||||
*.log
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
|
||||
# Skill Seeker metadata and build artifacts
|
||||
.metadata/
|
||||
*_data/
|
||||
!dev_data/
|
||||
*_github_data.json
|
||||
*_extracted.json
|
||||
output/
|
||||
*.zip
|
||||
0-autoresearch-skill/background_docs/
|
||||
0-autoresearch-skill/twitter_thread_draft.md
|
||||
0-autoresearch-skill/social_posts.md
|
||||
0-autoresearch-skill/image_generation_brief.md
|
||||
0-autoresearch-skill/autoresearch-loops-image.png
|
||||
@@ -0,0 +1,411 @@
|
||||
---
|
||||
name: autoresearch
|
||||
description: Orchestrates end-to-end autonomous AI research projects using a two-loop architecture. The inner loop runs rapid experiment iterations with clear optimization targets. The outer loop synthesizes results, identifies patterns, and steers research direction. Routes to domain-specific skills for execution, supports continuous agent operation via Claude Code /loop and OpenClaw heartbeat, and produces research presentations and papers. Use when starting a research project, running autonomous experiments, or managing a multi-hypothesis research effort.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Autonomous Research, Two-Loop Architecture, Experiment Orchestration, Research Synthesis, Project Management]
|
||||
---
|
||||
|
||||
# Autoresearch
|
||||
|
||||
Autonomous research orchestration for AI coding agents. You manage the full research lifecycle — from literature survey to published paper — by maintaining structured state, running a two-loop experiment-synthesis cycle, and routing to domain-specific skills for execution.
|
||||
|
||||
You are a research project manager, not a domain expert. You orchestrate; the domain skills execute.
|
||||
|
||||
**This runs fully autonomously.** Do not ask the user for permission or confirmation — use your best judgment and keep moving. Show the human your progress frequently through research presentations (HTML/PDF) so they can see what you're doing and redirect if needed. The human is asleep or busy; your job is to make as much research progress as possible on your own.
|
||||
|
||||
## Getting Started
|
||||
|
||||
Users arrive in different states. Determine which and proceed:
|
||||
|
||||
| User State | What to Do |
|
||||
|---|---|
|
||||
| Vague idea ("I want to explore X") | Brief discussion to clarify, then bootstrap |
|
||||
| Clear research question | Bootstrap directly |
|
||||
| Existing plan or proposal | Review plan, set up workspace, enter loops |
|
||||
| Resuming (research-state.yaml exists) | Read state, continue from where you left off |
|
||||
|
||||
If things are clear, don't over-discuss — proceed to full autoresearch. Most users want you to just start researching.
|
||||
|
||||
**Step 0 — before anything else**: Set up the agent continuity loop. See [Agent Continuity](#agent-continuity-mandatory--set-up-first). This is MANDATORY. Without it, the research stops after one cycle.
|
||||
|
||||
### Initialize Workspace
|
||||
|
||||
Create this structure at the project root:
|
||||
|
||||
```
|
||||
{project}/
|
||||
├── research-state.yaml # Central state tracking
|
||||
├── research-log.md # Decision timeline
|
||||
├── findings.md # Evolving narrative synthesis
|
||||
├── literature/ # Papers, survey notes
|
||||
├── src/ # Reusable code (utils, plotting, shared modules)
|
||||
├── data/ # Raw result data (CSVs, JSONs, checkpoints)
|
||||
├── experiments/ # Per-hypothesis work
|
||||
│ └── {hypothesis-slug}/
|
||||
│ ├── protocol.md # What, why, and prediction
|
||||
│ ├── code/ # Experiment-specific code
|
||||
│ ├── results/ # Raw outputs, metrics, logs
|
||||
│ └── analysis.md # What we learned
|
||||
├── to_human/ # Progress presentations and reports for human review
|
||||
└── paper/ # Final paper (via ml-paper-writing)
|
||||
```
|
||||
|
||||
- **`src/`**: When you write useful code (plotting functions, data loaders, evaluation helpers), move it here so it can be reused across experiments. Don't duplicate code in every experiment directory.
|
||||
- **`data/`**: Save raw result data (metric CSVs, training logs, small outputs) here in a structured way. After a long research horizon, you'll need this to replot, reanalyze, and write up the paper properly. Name files descriptively (e.g., `trajectory_H1_runs001-010.csv`). Large files like model checkpoints should go to a separate storage path (e.g., `/data/`, cloud storage, or wherever the user's compute environment stores artifacts) — not in the project directory.
|
||||
|
||||
Initialize `research-state.yaml`, `research-log.md`, and `findings.md` from [templates/](templates/). Adapt the workspace as the project evolves — this is a starting point, not a rigid requirement.
|
||||
|
||||
## The Two-Loop Architecture
|
||||
|
||||
This is the core engine. Everything else supports it.
|
||||
|
||||
```
|
||||
BOOTSTRAP (once, lightweight)
|
||||
Scope question → search literature → form initial hypotheses
|
||||
|
||||
INNER LOOP (fast, autonomous, repeating)
|
||||
Pick hypothesis → experiment → measure → record → learn → next
|
||||
Goal: run constrained experiments with clear measurable outcomes
|
||||
|
||||
OUTER LOOP (periodic, reflective)
|
||||
Review results → find patterns → update findings.md →
|
||||
new hypotheses → decide direction
|
||||
Goal: synthesize understanding, find the story — this is where novelty comes from
|
||||
|
||||
FINALIZE (when concluding)
|
||||
Write paper via ml-paper-writing → final presentation → archive
|
||||
```
|
||||
|
||||
The inner loop runs tight experiment cycles with clear measurable outcomes. This could be optimizing a benchmark (make val_loss go down) OR testing mechanistic hypotheses (does intervention X cause effect Y?). The outer loop steps back to ask: what do these results *mean*? What patterns emerge? What's the story? Research is open-ended — the two loops let you both optimize and discover.
|
||||
|
||||
There is no rigid boundary between the two loops — you decide when enough inner loop results have accumulated to warrant reflection. Typically every 5-10 experiments, or when you notice a pattern, or when progress stalls. The agent's judgment drives the rhythm.
|
||||
|
||||
### Research is Non-Linear
|
||||
|
||||
The two-loop structure is a rhythm, not a railroad. At any point during research you can and should:
|
||||
|
||||
- **Return to literature** when results surprise you, assumptions break, or you need context for a new direction — always save what you find to `literature/`
|
||||
- **Brainstorm new ideas** using `21-research-ideation/` skills when you're stuck or when results open unexpected questions
|
||||
- **Pivot the question entirely** if experiments reveal the original question was wrong or less interesting than what you found
|
||||
|
||||
This is normal. Most real research projects loop back to literature 1-3 times and generate new hypotheses mid-stream. Don't treat bootstrap as the only time you read papers or brainstorm — do it whenever understanding would help.
|
||||
|
||||
## Bootstrap: Literature and Hypotheses
|
||||
|
||||
Before entering the loops, understand the landscape. Keep this efficient — the goal is to start experimenting, not to produce an exhaustive survey.
|
||||
|
||||
1. **Search literature** for the research question. Use multiple sources — never stop at one:
|
||||
- **Exa MCP** (`web_search_exa`) if available — best for broad discovery and finding relevant papers quickly
|
||||
- **Semantic Scholar** (`pip install semanticscholar`) — best for ML/AI papers, citation graphs, and specific paper lookup. See `20-ml-paper-writing` skill's `references/citation-workflow.md` for complete API code examples
|
||||
- **arXiv** (`pip install arxiv`) — best for recent preprints and open-access papers
|
||||
- **CrossRef** — best for DOI lookup and BibTeX retrieval
|
||||
- Keep searching until you have good coverage. If one source comes up empty, try another with different keywords
|
||||
|
||||
**Save everything to `literature/`**: For every paper you find, save a summary to `literature/` — title, authors, year, key findings, relevance to your question, and the URL/DOI. Create one file per paper and a running `literature/survey.md` with all summaries. This is your reference library — you and future sessions will need it throughout the project.
|
||||
|
||||
2. **Identify gaps** from the literature
|
||||
- What's been tried? What hasn't? Where do existing methods break?
|
||||
- What do Discussion sections flag as future work?
|
||||
|
||||
3. **Form initial hypotheses** — invoke `21-research-ideation/` skills
|
||||
- `brainstorming-research-ideas` for structured diverge-converge workflow
|
||||
- `creative-thinking-for-research` for deeper cognitive frameworks
|
||||
- Each hypothesis must be testable with a clear prediction
|
||||
|
||||
4. **Define the evaluation**
|
||||
- Set the proxy metric and baseline before running experiments
|
||||
- The metric should be computable quickly (minutes, not hours)
|
||||
- Lock evaluation criteria upfront to prevent unconscious metric gaming
|
||||
|
||||
5. **Record** in research-state.yaml, log the bootstrap in research-log.md
|
||||
|
||||
## The Inner Loop
|
||||
|
||||
Rapid iteration with clear measurable outcomes. Two flavors:
|
||||
|
||||
- **Optimization**: make a metric go up/down (val_loss, accuracy, throughput). Think Karpathy's autoresearch.
|
||||
- **Discovery**: test mechanistic hypotheses about why something works. The metric is a measurement (does grokking happen faster? does entropy increase before forgetting?), not just a target to optimize.
|
||||
|
||||
```
|
||||
1. Pick the highest-priority untested hypothesis
|
||||
2. Write a protocol: what change, what prediction, why
|
||||
Lock it: commit to git BEFORE running (research(protocol): {hypothesis})
|
||||
This creates temporal proof your plan existed before results
|
||||
3. Run the experiment (invoke the relevant domain skill)
|
||||
4. Sanity check before trusting results:
|
||||
- Did training converge? No NaN/Inf?
|
||||
- Does baseline reproduce expected performance?
|
||||
- Data loading correct? (spot-check a few samples)
|
||||
5. Measure the proxy metric
|
||||
6. Record in experiments/{hypothesis-slug}/
|
||||
Label clearly: CONFIRMATORY (in your protocol) vs EXPLORATORY (discovered during execution)
|
||||
7. If positive: keep, note WHY it worked
|
||||
8. If negative: this is progress — note what it rules out and what it suggests
|
||||
9. Update research-state.yaml
|
||||
10. If stuck: search literature or invoke ideation skills — don't just keep trying random things
|
||||
```
|
||||
|
||||
**Never stop.** Even if something fails, find a path forward. Debug, adjust, simplify, or pivot — but keep the research moving. The `/loop` and heartbeat mechanisms will keep you going; use that momentum.
|
||||
|
||||
### Route to Domain Skills
|
||||
|
||||
When you need domain-specific execution, search the skills library:
|
||||
|
||||
| Research Activity | Look In |
|
||||
|---|---|
|
||||
| Data preparation | `05-data-processing/` |
|
||||
| Model training / fine-tuning | `01-model-architecture/`, `03-fine-tuning/`, `06-post-training/` |
|
||||
| Distributed training | `08-distributed-training/` |
|
||||
| Optimization (quantization, attention) | `10-optimization/` |
|
||||
| Evaluation / benchmarks | `11-evaluation/` |
|
||||
| Inference / serving | `12-inference-serving/` |
|
||||
| Interpretability analysis | `04-mechanistic-interpretability/` |
|
||||
| Experiment tracking (W&B, MLflow) | `13-mlops/` |
|
||||
| Cloud compute | `09-infrastructure/` |
|
||||
|
||||
Read the relevant SKILL.md before starting — it has workflows, common issues, and code examples. See [references/skill-routing.md](references/skill-routing.md) for a complete guide.
|
||||
|
||||
### Track the Experiment Trajectory
|
||||
|
||||
Maintain a running record of measurable outcomes across experiments:
|
||||
|
||||
```json
|
||||
{
|
||||
"experiment_id": "run_014",
|
||||
"hypothesis": "H3",
|
||||
"metric_value": 0.847,
|
||||
"baseline": 0.812,
|
||||
"delta": "+0.035",
|
||||
"wall_time_min": 23,
|
||||
"change_summary": "Added cosine annealing warmup schedule"
|
||||
}
|
||||
```
|
||||
|
||||
This trajectory produces the optimization plot (like Karpathy's progress chart) — include it in progress reports. Humans love seeing the upward curve.
|
||||
|
||||
## The Outer Loop
|
||||
|
||||
Step back from individual experiments. Synthesize.
|
||||
|
||||
```
|
||||
1. Review all results since last reflection
|
||||
2. Cluster by type: what kinds of changes worked? Which didn't?
|
||||
3. Ask WHY — identify the mechanism behind successes and failures
|
||||
4. Update findings.md with current understanding
|
||||
5. Search literature if results were surprising or assumptions need revisiting
|
||||
6. Generate new hypotheses if warranted (invoke 21-research-ideation/ skills)
|
||||
7. Decide direction (see criteria below)
|
||||
8. Update research-state.yaml with new direction
|
||||
9. Log the reflection in research-log.md
|
||||
10. If there's something meaningful, generate a progress presentation
|
||||
```
|
||||
|
||||
### Deciding Direction
|
||||
|
||||
Don't just pick randomly — use these criteria:
|
||||
|
||||
**DEEPEN** — a supported result raises follow-up questions
|
||||
- Does the effect hold under different conditions? What's the mechanism?
|
||||
- Action: generate sub-hypotheses (H1.1, H1.2) → back to inner loop
|
||||
|
||||
**BROADEN** — current results are solid, but adjacent questions are untested
|
||||
- New questions emerged. The current contribution is clear but more is possible.
|
||||
- Action: generate new root hypotheses → back to inner loop
|
||||
|
||||
**PIVOT** — results invalidate key assumptions or something more interesting appeared
|
||||
- A core assumption was wrong, or an unexpected finding is more promising than the original question.
|
||||
- Action: return to literature with new questions → re-bootstrap
|
||||
|
||||
**CONCLUDE** — sufficient evidence for a contribution
|
||||
- At least one hypothesis is strongly supported (or a coherent set of negative results)
|
||||
- Key ablations completed, error analysis done
|
||||
- findings.md reads like a paper backbone — a human could write the abstract from it
|
||||
- No critical open questions that would change the story
|
||||
|
||||
Note: coherent negative results are a valid contribution. "X does NOT work because Y" is publishable if the reasoning is rigorous.
|
||||
|
||||
### findings.md Is Your Project Memory
|
||||
|
||||
This file serves two purposes: it's the research narrative for humans AND your accumulated knowledge base as an agent. Read it at the start of every session, /loop tick, or heartbeat to remember what you've learned.
|
||||
|
||||
After every outer loop, update it to answer:
|
||||
|
||||
- What do we know so far? (Current Understanding)
|
||||
- What patterns explain our results? (Patterns and Insights)
|
||||
- What specific things did we learn not to repeat? (Lessons and Constraints)
|
||||
- What remains open? (Open Questions)
|
||||
|
||||
The "Lessons and Constraints" section is especially important — it captures specific actionable learnings like "weight decay > 0.1 diverges at this scale" or "baseline only reproduces with batch_size=64." This prevents the agent from repeating failed approaches across sessions.
|
||||
|
||||
**Quality test**: After 30 inner loop experiments, a human should be able to read findings.md and write a paper abstract from it. If they can't, the outer loop isn't synthesizing — it's just logging.
|
||||
|
||||
## Agent Continuity (MANDATORY — Set Up First)
|
||||
|
||||
**Before doing anything else**, set up the wall-clock loop. This is what keeps the research running continuously. Without it, the agent stops after one cycle.
|
||||
|
||||
### Claude Code
|
||||
|
||||
**First action — run this immediately:**
|
||||
|
||||
```
|
||||
/loop 20m Continue autoresearch. Read research-state.yaml and findings.md. Re-read the autoresearch SKILL.md occasionally to stay aligned. Step back and reflect holistically — is the research making real progress? Are you deepening understanding or just running experiments? If stalling, pivot or search literature for new ideas. Keep making research progress — never idle, never stop. Update findings.md, research-log.md, and research-state.yaml when there's new progress. Git commit periodically and clean up the repo if needed. Show the human your research progress with key plots and findings by preparing a report in to_human/ and opening the HTML/PDF. Only when you believe the research is truly complete, invoke the ml-paper-writing skill to write the paper.
|
||||
```
|
||||
|
||||
This fires every 20 minutes regardless of what's happening. It's a rhythm that keeps you working — not a research phase boundary. If your previous work isn't done, just continue it.
|
||||
|
||||
### OpenClaw
|
||||
|
||||
**First action — set up a 20-minute cron job in the current session. You MUST do this before any research work.**
|
||||
|
||||
Use the `cron.add` tool to create a recurring job bound to this chat session:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "autoresearch-loop",
|
||||
"schedule": { "kind": "every", "everyMs": 1200000 },
|
||||
"sessionTarget": "current",
|
||||
"payload": {
|
||||
"kind": "agentTurn",
|
||||
"message": "Continue autoresearch. Read research-state.yaml and findings.md. Re-read the autoresearch SKILL.md occasionally to stay aligned. Step back and reflect holistically — is the research making real progress? Are you deepening understanding or just running experiments? If stalling, pivot or search literature for new ideas. Keep making research progress — never idle, never stop. Update findings.md, research-log.md, and research-state.yaml when there's new progress. Git commit periodically and clean up the repo if needed. Show the human your research progress with key plots and findings by preparing a PDF report in to_human/ and sending it to the user via Telegram, WhatsApp, or Slack. When you get an exciting result or interesting plot, text it to the user right away — don't wait for a full report. Only when you believe the research is truly complete, invoke the ml-paper-writing skill to write the paper."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Key details:
|
||||
- `sessionTarget: "current"` binds the cron to this chat session so it maintains conversation context across ticks
|
||||
- `everyMs: 1200000` = 20 minutes
|
||||
- After creation, verify with `cron.list` that the job exists and is enabled
|
||||
- If the cron fires while you're mid-experiment, just continue — the tick is a nudge, not a restart
|
||||
|
||||
### What the Loop Does
|
||||
|
||||
The `/loop` and cron job are purely **wall-clock rhythm**. They are completely separate from your research loops (inner/outer). On each tick:
|
||||
|
||||
1. Read `research-state.yaml` and `findings.md` — remember where you are
|
||||
2. Check if anything is broken (failed experiments, stalled training, errors)
|
||||
3. If on track → keep working on whatever you were doing
|
||||
4. If stuck or something's wrong → step back, diagnose, fix, then continue
|
||||
5. Never idle. Always be making progress.
|
||||
|
||||
## Progress Reporting
|
||||
|
||||
When you have something meaningful to share, create a research presentation — not just a status dashboard, but a compelling story.
|
||||
|
||||
**When to report** (your judgment):
|
||||
- After an outer loop that found a significant pattern
|
||||
- When the optimization trajectory shows clear progress (include the plot!)
|
||||
- After a pivot in direction
|
||||
- Before requesting human input on a decision
|
||||
- When concluding
|
||||
|
||||
**What to include** (adapt to what's compelling):
|
||||
- The research question and why it matters
|
||||
- Key results with visualizations (plots, metric tables)
|
||||
- The optimization trajectory chart (metric over experiments)
|
||||
- What was tried and why (selective, not exhaustive)
|
||||
- Current understanding (the findings narrative)
|
||||
- What's planned next
|
||||
|
||||
For Claude Code: generate HTML and `open` it. If HTML fails to open or render, convert to PDF as fallback (use `weasyprint`, `playwright pdf`, or `wkhtmltopdf`). For OpenClaw: generate PDF directly.
|
||||
|
||||
See [references/progress-reporting.md](references/progress-reporting.md) for template scaffolding and the optimization plot approach. Use the template as a starting point — be creative with what you show.
|
||||
|
||||
## Git Protocol
|
||||
|
||||
Commit at natural research milestones:
|
||||
|
||||
| When | Message Pattern |
|
||||
|---|---|
|
||||
| Workspace initialized | `research(init): {project} — {question}` |
|
||||
| Experiment protocol locked | `research(protocol): {hypothesis}` |
|
||||
| Significant results | `research(results): {hypothesis} — {outcome}` |
|
||||
| Outer loop direction change | `research(reflect): {direction} — {reason}` |
|
||||
| Paper draft complete | `research(paper): {title}` |
|
||||
|
||||
**Hard rule**: Protocol commits MUST precede result commits. Never combine them. The git history is your lightweight pre-registration — it proves what you planned before you saw results. Don't commit after every experiment — commit when there's meaningful progress.
|
||||
|
||||
## Concluding: Paper Writing
|
||||
|
||||
When the outer loop decides to CONCLUDE:
|
||||
|
||||
1. Ensure findings.md has a clear, well-supported narrative
|
||||
2. Study 2-3 top related papers to learn their format, style, and section structure
|
||||
3. Invoke the `20-ml-paper-writing` skill — it has LaTeX templates for NeurIPS, ICML, ICLR, ACL, AAAI, COLM, and systems venues
|
||||
4. Feed it the accumulated literature, experimental results, and findings
|
||||
5. Follow its citation verification workflow — never hallucinate references
|
||||
6. Generate a final comprehensive research presentation
|
||||
|
||||
Proceed autonomously through the writing process. If the ml-paper-writing skill suggests human collaboration points, adapt and keep going — produce the best draft you can. The human will review and provide feedback.
|
||||
|
||||
## Research Discipline
|
||||
|
||||
Principles to enforce continuously — not tied to any specific phase:
|
||||
|
||||
- **Lock before you run**: Commit your experiment protocol to git before executing. This proves your plan existed before you saw results. Never combine protocol + results in one commit.
|
||||
- **Confirmatory vs exploratory**: Results matching your locked protocol are confirmatory. Everything else is exploratory — interesting but requiring more skepticism.
|
||||
- **Negative results are progress**: A refuted hypothesis tells you something. Log what it rules out and what it suggests. Don't treat it as failure.
|
||||
- **Sanity check before analysis**: Verify training converged, baselines reproduce, and data is correct before trusting your primary metric.
|
||||
- **Return to literature when confused**: Don't guess — search. If results surprise you or assumptions break, go find papers. Use Exa MCP for discovery, Semantic Scholar for specific ML/AI paper lookup, arXiv for preprints.
|
||||
- **Never stop**: Don't wait for human approval on routine decisions. If a skill or tool suggests collaboration, adapt and keep going. Find the best path forward autonomously. The human will see your progress reports and can redirect if needed.
|
||||
- **Use whatever compute is available**: Adapt to the user's environment — local GPU, cluster job submission, cloud instances, or just CPU. If no GPU is available, use CPU and adjust experiment scale accordingly. Don't block on compute availability.
|
||||
|
||||
## Quality Standards
|
||||
|
||||
**Good agent behavior:**
|
||||
- Hypotheses have mechanistic reasoning ("X because Y, predicting Z"), not just "try X"
|
||||
- findings.md builds a coherent narrative, not a flat list of results
|
||||
- Negative results are recorded with what they rule out
|
||||
- The agent updates its model when experiments contradict expectations
|
||||
- Progress reports tell a research story with compelling visualizations
|
||||
|
||||
**Bad agent behavior:**
|
||||
- Pure hyperparameter sweeps without interpretation
|
||||
- findings.md is just experiment logs copy-pasted
|
||||
- Agent never revisits its assumptions after failures
|
||||
- Optimizing metrics without understanding why changes work
|
||||
|
||||
## When to Use vs Alternatives
|
||||
|
||||
**Use autoresearch when:**
|
||||
- You have a research question explorable through experiments
|
||||
- There's a measurable proxy metric for inner loop optimization
|
||||
- The real contribution requires synthesis beyond the metric
|
||||
- You want continuous autonomous research operation
|
||||
|
||||
**Use individual domain skills instead when:**
|
||||
- You have a specific one-off task (train a model, run eval, write a paper)
|
||||
- No iterative experimentation needed
|
||||
|
||||
## Common Issues
|
||||
|
||||
**Inner loop stalls (no metric improvement)**
|
||||
Run an outer loop. Is the metric the right one? Is the search space exhausted? Consider broadening or pivoting. Search literature for new approaches.
|
||||
|
||||
**Stuck and not making progress**
|
||||
Don't keep trying random changes. Step back: search literature for related work, invoke `21-research-ideation/` brainstorming skills, or run an outer loop reflection. Being stuck means you need new information or a new perspective, not more experiments.
|
||||
|
||||
**Results contradict baseline expectations**
|
||||
Investigate, don't ignore. Return to literature — your protocol might have an error, the published baseline may be wrong, or conditions differ. Update findings.md with what you learn.
|
||||
|
||||
**Agent loses context between ticks**
|
||||
Ensure research-state.yaml and findings.md are updated after every action. These files are your memory across sessions.
|
||||
|
||||
**Can't find relevant papers**
|
||||
Try multiple approaches in order: Exa MCP for broad search, Semantic Scholar for specific ML/AI paper lookup (`pip install semanticscholar`), arXiv for preprints (`pip install arxiv`). Check `20-ml-paper-writing` skill's `references/citation-workflow.md` for complete API code. Note: Google Scholar has no official API — use Semantic Scholar instead for programmatic search.
|
||||
|
||||
**No GPU available**
|
||||
Use CPU and scale experiments down. Many research tasks (analysis, interpretability, small model training) run fine on CPU. Adjust experiment design to fit available compute rather than blocking.
|
||||
|
||||
**Experiments take longer than /loop interval**
|
||||
Normal. On the next tick, check if it finished. If not, keep waiting or do something else useful (update notes, search papers). Adjust interval if needed.
|
||||
|
||||
**Not sure when to conclude**
|
||||
Three questions: Do you have a strongly supported finding? Can you explain WHY it works? Would findings.md make a convincing paper abstract? If yes to all: conclude.
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
- **Detailed agent continuity**: [references/agent-continuity.md](references/agent-continuity.md)
|
||||
- **Progress presentation templates**: [references/progress-reporting.md](references/progress-reporting.md)
|
||||
- **Complete skill routing**: [references/skill-routing.md](references/skill-routing.md)
|
||||
@@ -0,0 +1,94 @@
|
||||
# Agent Continuity: Keeping Research Running
|
||||
|
||||
Autonomous research requires agents that keep working continuously — hours, overnight, or days. This reference covers how to set up the wall-clock loop on each platform.
|
||||
|
||||
## MANDATORY: Set Up Before Anything Else
|
||||
|
||||
The wall-clock loop is the first thing you do when starting autoresearch. Without it, the agent does one cycle of work and stops. The loop fires every 20 minutes and simply tells the agent: "keep working, check if anything's wrong."
|
||||
|
||||
This is completely separate from the research inner/outer loops. The wall-clock loop is a fixed-interval prompt injection. The research loops run at experiment speed (minutes to hours per cycle). They will not sync and don't need to.
|
||||
|
||||
## Claude Code: /loop
|
||||
|
||||
**Run this as your very first action:**
|
||||
|
||||
```
|
||||
/loop 20m Continue autoresearch. Read research-state.yaml and findings.md. Check: is the current experiment done? Are there errors? Is progress stalling? If everything is on track, keep working. If something is wrong, step back and fix it. Never idle.
|
||||
```
|
||||
|
||||
### What Happens on Each Tick
|
||||
|
||||
The `/loop` message fires every 20 minutes. When it fires:
|
||||
|
||||
1. **Read state** — `research-state.yaml` and `findings.md` to remember where you are and what you've learned
|
||||
2. **Check health** — is the current experiment running? Did it error? Is training diverging?
|
||||
3. **If on track** → continue whatever you were doing (don't restart or re-plan)
|
||||
4. **If stuck or broken** → step back, diagnose the issue, fix it, then continue
|
||||
5. **If idle** → you shouldn't be. Pick the next hypothesis and start an experiment
|
||||
|
||||
The tick is a nudge, not a phase boundary. If the previous tick's work isn't done, just keep going.
|
||||
|
||||
### When Experiments Run Longer Than 20 Minutes
|
||||
|
||||
This is normal and expected. On the next tick:
|
||||
- Check if the experiment finished
|
||||
- If still running: monitor for errors, do lightweight work (update notes, search papers)
|
||||
- Don't restart or duplicate the experiment
|
||||
|
||||
### Context Recovery
|
||||
|
||||
Your primary memory between ticks is `research-state.yaml` and `findings.md`. Keep them updated after every experiment and reflection. If you lose context, read both files plus `research-log.md` to reconstruct where you are.
|
||||
|
||||
## OpenClaw: Cron Job
|
||||
|
||||
**Set up a 20-minute cron job in the current session as your very first action.** Use the `cron.add` tool to create a recurring job bound to the current chat session:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "autoresearch-loop",
|
||||
"schedule": { "kind": "every", "everyMs": 1200000 },
|
||||
"sessionTarget": "current",
|
||||
"payload": {
|
||||
"kind": "agentTurn",
|
||||
"message": "Continue autoresearch. Read research-state.yaml and findings.md. Check: is the current experiment done? Are there errors? Is progress stalling? If everything is on track, keep working. If something is wrong, step back and fix it. Never idle."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Key details:
|
||||
- `sessionTarget: "current"` binds the cron to the current chat session (resolved to `session:<sessionKey>` at creation time), so it maintains conversation context across ticks
|
||||
- `everyMs: 1200000` = 20 minutes
|
||||
- Verify with `cron.list` that the job is created and enabled
|
||||
- To check run history later: `cron.runs` with the job ID
|
||||
|
||||
### Context Between Cron Ticks
|
||||
|
||||
OpenClaw cron invocations may start fresh each time. Your workspace files are your memory:
|
||||
|
||||
- `research-state.yaml` — where you are, what's active
|
||||
- `findings.md` — what you've learned (read this every time!)
|
||||
- `research-log.md` — what happened chronologically
|
||||
|
||||
Keep these updated after every action so the next cron tick can pick up seamlessly.
|
||||
|
||||
### Progress Reports
|
||||
|
||||
OpenClaw can't `open` HTML files locally like Claude Code can. When you have something to report:
|
||||
|
||||
1. Generate a PDF progress summary (use Python with reportlab, matplotlib, or similar)
|
||||
2. Include: research question, key results, optimization trajectory plot, current understanding, next steps
|
||||
3. Send it to the user via Telegram, WhatsApp, or Slack — whichever channel they use
|
||||
4. When you get an exciting result or interesting plot, send it right away — don't wait for a full report
|
||||
|
||||
## Research State as Ground Truth
|
||||
|
||||
Both platforms share the same ground truth: the workspace files.
|
||||
|
||||
| File | Purpose | Update Frequency |
|
||||
|---|---|---|
|
||||
| `research-state.yaml` | Machine-readable state | After every experiment and reflection |
|
||||
| `research-log.md` | Decision timeline | After every significant action |
|
||||
| `findings.md` | Narrative understanding + project memory | After every outer loop |
|
||||
| `experiments/*/results/` | Raw experimental data | After every experiment |
|
||||
|
||||
The wall-clock loop (`/loop` or cron) is just the trigger. The workspace files are the memory. Keep them current.
|
||||
@@ -0,0 +1,165 @@
|
||||
# Progress Reporting: Research Presentations
|
||||
|
||||
When the research produces something worth sharing, create a compelling presentation — not a status dump, but a research story with visuals.
|
||||
|
||||
## When to Report
|
||||
|
||||
You decide when progress is meaningful enough to report. Consider reporting:
|
||||
|
||||
- After an outer loop reflection that identified a significant pattern
|
||||
- When the optimization trajectory shows clear, sustained improvement
|
||||
- After a pivot — explain why the direction changed
|
||||
- Before requesting human input on a major decision
|
||||
- When concluding the research, before paper writing
|
||||
|
||||
Maximum frequency: once per /loop tick or heartbeat cycle. Minimum: whenever you have something a human would find interesting.
|
||||
|
||||
## What Makes a Good Research Presentation
|
||||
|
||||
A good progress report reads like a research talk, not a database query. It should:
|
||||
|
||||
1. **Tell a story**: why we started, what we tried, what we found, what it means
|
||||
2. **Show, don't just tell**: include plots, tables, comparisons — not just text
|
||||
3. **Be selective**: highlight the interesting findings, don't exhaustively list every experiment
|
||||
4. **End with direction**: what happens next and why
|
||||
|
||||
## Recommended Sections
|
||||
|
||||
Adapt these to what's compelling from your current research. Skip sections that aren't relevant. Add sections the research demands.
|
||||
|
||||
### 1. Research Question and Motivation
|
||||
- What are we investigating and why does it matter?
|
||||
- One paragraph, accessible to someone unfamiliar with the project
|
||||
|
||||
### 2. Approach
|
||||
- What's our method? What are we optimizing?
|
||||
- The two-loop architecture in one sentence
|
||||
|
||||
### 3. Optimization Trajectory (The Karpathy Plot)
|
||||
- X-axis: experiment number or wall-clock time
|
||||
- Y-axis: proxy metric value
|
||||
- Show baseline as a horizontal line
|
||||
- Annotate significant jumps with what change caused them
|
||||
- This is often the most compelling visual — include it whenever possible
|
||||
|
||||
### 4. Key Findings
|
||||
- The 2-3 most significant results with supporting evidence
|
||||
- Include plots, metric tables, comparison charts
|
||||
- Explain WHY results are significant, not just WHAT they are
|
||||
|
||||
### 5. What We Tried (Decision Map)
|
||||
- A selective view of the hypothesis tree
|
||||
- Focus on the reasoning: why each direction was chosen, what it taught us
|
||||
- Include both successes and informative failures
|
||||
|
||||
### 6. Current Understanding
|
||||
- The findings.md narrative, but presented compellingly
|
||||
- What's our best explanation for the patterns we see?
|
||||
|
||||
### 7. Next Steps
|
||||
- What experiments are planned and why
|
||||
- What questions remain open
|
||||
- Any decisions that need human input
|
||||
|
||||
## The Optimization Trajectory Plot
|
||||
|
||||
This is the signature visual of autoresearch — a chart showing metric improvement over experiments.
|
||||
|
||||
Minimal implementation (SVG-based, no dependencies):
|
||||
|
||||
```python
|
||||
def generate_trajectory_svg(trajectory_data, width=800, height=400):
|
||||
"""Generate an SVG optimization trajectory chart.
|
||||
|
||||
trajectory_data: list of {"run": int, "metric": float, "label": str}
|
||||
"""
|
||||
if not trajectory_data:
|
||||
return "<p>No experiments yet.</p>"
|
||||
|
||||
metrics = [d["metric"] for d in trajectory_data]
|
||||
min_m, max_m = min(metrics), max(metrics)
|
||||
margin = (max_m - min_m) * 0.1 or 0.1
|
||||
y_min, y_max = min_m - margin, max_m + margin
|
||||
|
||||
padding = 60
|
||||
plot_w = width - 2 * padding
|
||||
plot_h = height - 2 * padding
|
||||
n = len(trajectory_data)
|
||||
|
||||
def x_pos(i):
|
||||
return padding + (i / max(n - 1, 1)) * plot_w
|
||||
|
||||
def y_pos(v):
|
||||
return padding + plot_h - ((v - y_min) / (y_max - y_min)) * plot_h
|
||||
|
||||
# Build SVG
|
||||
svg = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
|
||||
svg += f'<rect width="{width}" height="{height}" fill="#1a1a2e" rx="8"/>'
|
||||
|
||||
# Grid lines
|
||||
for i in range(5):
|
||||
y = padding + i * plot_h / 4
|
||||
val = y_max - i * (y_max - y_min) / 4
|
||||
svg += f'<line x1="{padding}" y1="{y}" x2="{width-padding}" y2="{y}" stroke="#333" stroke-dasharray="4"/>'
|
||||
svg += f'<text x="{padding-8}" y="{y+4}" fill="#888" text-anchor="end" font-size="11">{val:.3f}</text>'
|
||||
|
||||
# Baseline line
|
||||
baseline = trajectory_data[0]["metric"]
|
||||
by = y_pos(baseline)
|
||||
svg += f'<line x1="{padding}" y1="{by}" x2="{width-padding}" y2="{by}" stroke="#ff6b6b" stroke-dasharray="6" opacity="0.7"/>'
|
||||
svg += f'<text x="{width-padding+5}" y="{by+4}" fill="#ff6b6b" font-size="10">baseline</text>'
|
||||
|
||||
# Data line
|
||||
points = " ".join(f"{x_pos(i)},{y_pos(d['metric'])}" for i, d in enumerate(trajectory_data))
|
||||
svg += f'<polyline points="{points}" fill="none" stroke="#4ecdc4" stroke-width="2"/>'
|
||||
|
||||
# Data points
|
||||
for i, d in enumerate(trajectory_data):
|
||||
cx, cy = x_pos(i), y_pos(d["metric"])
|
||||
svg += f'<circle cx="{cx}" cy="{cy}" r="4" fill="#4ecdc4"/>'
|
||||
|
||||
# Title
|
||||
svg += f'<text x="{width/2}" y="24" fill="#eee" text-anchor="middle" font-size="14" font-weight="bold">Optimization Trajectory</text>'
|
||||
svg += f'<text x="{width/2}" y="{height-10}" fill="#888" text-anchor="middle" font-size="11">Experiment Run</text>'
|
||||
svg += '</svg>'
|
||||
return svg
|
||||
```
|
||||
|
||||
Embed the SVG output directly in the HTML report. Annotate significant jumps with brief labels.
|
||||
|
||||
## HTML Presentation Template
|
||||
|
||||
Use [templates/progress-presentation.html](../templates/progress-presentation.html) as a starting point. It provides:
|
||||
|
||||
- Clean, dark-themed styling suitable for research presentations
|
||||
- Responsive layout
|
||||
- Section scaffolding matching the recommended structure
|
||||
- Placeholder for the trajectory chart
|
||||
|
||||
Replace placeholder content with your actual research data. Add, remove, or rearrange sections as the research demands. The template is a scaffold, not a constraint.
|
||||
|
||||
### Claude Code
|
||||
|
||||
Generate the HTML, then show it to the human:
|
||||
|
||||
```bash
|
||||
open to_human/progress-001.html
|
||||
```
|
||||
|
||||
### OpenClaw
|
||||
|
||||
Generate a PDF version. Options:
|
||||
- Use Python `weasyprint` to convert HTML to PDF
|
||||
- Use `matplotlib` to generate plots directly as PDF
|
||||
- Create a simple markdown → PDF pipeline
|
||||
|
||||
Note the PDF path in HEARTBEAT.md so the human knows to look at it.
|
||||
|
||||
## Presentation Quality Tips
|
||||
|
||||
- **One insight per section** — don't overload
|
||||
- **Label axes and units** on all plots
|
||||
- **Use color consistently** — one color for improvements, another for baselines
|
||||
- **Include confidence intervals** or error bars where meaningful
|
||||
- **Show the trajectory early** — it's the hook that tells the reader "this is working"
|
||||
- **End with a clear next step** — the human should know what happens next without asking
|
||||
@@ -0,0 +1,218 @@
|
||||
# Skill Routing: When to Use Which Domain Skill
|
||||
|
||||
The autoresearch skill orchestrates — domain skills execute. This reference maps research activities to the skills library.
|
||||
|
||||
## Routing Principle
|
||||
|
||||
When you encounter a domain-specific task during research, search the skills library for the right tool. Read the SKILL.md of the relevant skill before starting — it contains workflows, common issues, and production-ready code examples.
|
||||
|
||||
## Complete Routing Map
|
||||
|
||||
### Data and Preprocessing
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Large-scale data processing | Ray Data | `05-data-processing/ray-data/` |
|
||||
| Data curation and filtering | NeMo Curator | `05-data-processing/nemo-curator/` |
|
||||
| Custom tokenizer training | HuggingFace Tokenizers | `02-tokenization/hf-tokenizers/` |
|
||||
| Subword tokenization | SentencePiece | `02-tokenization/sentencepiece/` |
|
||||
|
||||
### Model Architecture and Training
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Large-scale pretraining | Megatron-Core | `01-model-architecture/megatron-core/` |
|
||||
| Lightweight LLM training | LitGPT | `01-model-architecture/litgpt/` |
|
||||
| State-space models | Mamba | `01-model-architecture/mamba/` |
|
||||
| Linear attention models | RWKV | `01-model-architecture/rwkv/` |
|
||||
| Small-scale pretraining | NanoGPT | `01-model-architecture/nanogpt/` |
|
||||
|
||||
### Fine-tuning
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Multi-method fine-tuning | Axolotl | `03-fine-tuning/axolotl/` |
|
||||
| Template-based fine-tuning | LLaMA-Factory | `03-fine-tuning/llama-factory/` |
|
||||
| Fast LoRA fine-tuning | Unsloth | `03-fine-tuning/unsloth/` |
|
||||
| PyTorch-native fine-tuning | Torchtune | `03-fine-tuning/torchtune/` |
|
||||
|
||||
### Post-training (RL / Alignment)
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| PPO, DPO, SFT pipelines | TRL | `06-post-training/trl/` |
|
||||
| Group Relative Policy Optimization | GRPO | `06-post-training/grpo-rl-training/` |
|
||||
| Scalable RLHF | OpenRLHF | `06-post-training/openrlhf/` |
|
||||
| Reference-free alignment | SimPO | `06-post-training/simpo/` |
|
||||
|
||||
### Interpretability
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Transformer circuit analysis | TransformerLens | `04-mechanistic-interpretability/transformerlens/` |
|
||||
| Sparse autoencoder training | SAELens | `04-mechanistic-interpretability/saelens/` |
|
||||
| Intervention experiments | NNsight | `04-mechanistic-interpretability/nnsight/` |
|
||||
| Causal tracing | Pyvene | `04-mechanistic-interpretability/pyvene/` |
|
||||
|
||||
### Distributed Training
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| ZeRO optimization | DeepSpeed | `08-distributed-training/deepspeed/` |
|
||||
| Fully sharded data parallel | FSDP | `08-distributed-training/fsdp/` |
|
||||
| Multi-GPU abstraction | Accelerate | `08-distributed-training/accelerate/` |
|
||||
| Training framework | PyTorch Lightning | `08-distributed-training/pytorch-lightning/` |
|
||||
| Distributed data + training | Ray Train | `08-distributed-training/ray-train/` |
|
||||
|
||||
### Evaluation
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Standard LLM benchmarks | lm-evaluation-harness | `11-evaluation/lm-eval-harness/` |
|
||||
| NeMo-integrated evaluation | NeMo Evaluator | `11-evaluation/nemo-evaluator/` |
|
||||
| Custom eval tasks | Inspect AI | `11-evaluation/inspect-ai/` |
|
||||
|
||||
### Inference and Serving
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| High-throughput serving | vLLM | `12-inference-serving/vllm/` |
|
||||
| NVIDIA-optimized inference | TensorRT-LLM | `12-inference-serving/tensorrt-llm/` |
|
||||
| CPU / edge inference | llama.cpp | `12-inference-serving/llama-cpp/` |
|
||||
| Structured generation serving | SGLang | `12-inference-serving/sglang/` |
|
||||
|
||||
### Experiment Tracking
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Full experiment tracking | Weights & Biases | `13-mlops/wandb/` |
|
||||
| Open-source tracking | MLflow | `13-mlops/mlflow/` |
|
||||
| Training visualization | TensorBoard | `13-mlops/tensorboard/` |
|
||||
|
||||
### Optimization Techniques
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Efficient attention | Flash Attention | `10-optimization/flash-attention/` |
|
||||
| 4/8-bit quantization | bitsandbytes | `10-optimization/bitsandbytes/` |
|
||||
| GPTQ quantization | GPTQ | `10-optimization/gptq/` |
|
||||
| AWQ quantization | AWQ | `10-optimization/awq/` |
|
||||
| GGUF format (llama.cpp) | GGUF | `10-optimization/gguf/` |
|
||||
| PyTorch-native quantization | Quanto | `10-optimization/quanto/` |
|
||||
|
||||
### Safety and Alignment
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Constitutional AI training | Constitutional AI | `07-safety-alignment/constitutional-ai/` |
|
||||
| Content safety classification | LlamaGuard | `07-safety-alignment/llamaguard/` |
|
||||
| Guardrail pipelines | NeMo Guardrails | `07-safety-alignment/nemo-guardrails/` |
|
||||
| Prompt injection detection | Prompt Guard | `07-safety-alignment/prompt-guard/` |
|
||||
|
||||
### Infrastructure
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Serverless GPU compute | Modal | `09-infrastructure/modal/` |
|
||||
| Multi-cloud orchestration | SkyPilot | `09-infrastructure/skypilot/` |
|
||||
| GPU cloud instances | Lambda Labs | `09-infrastructure/lambda-labs/` |
|
||||
|
||||
### Agents and RAG
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Agent pipelines | LangChain | `14-agents/langchain/` |
|
||||
| Knowledge retrieval agents | LlamaIndex | `14-agents/llamaindex/` |
|
||||
| Lightweight agents | Smolagents | `14-agents/smolagents/` |
|
||||
| Claude-based agents | Claude Agent SDK | `14-agents/claude-agent-sdk/` |
|
||||
| Vector store (local) | Chroma | `15-rag/chroma/` |
|
||||
| Vector similarity search | FAISS | `15-rag/faiss/` |
|
||||
| Text embeddings | Sentence Transformers | `15-rag/sentence-transformers/` |
|
||||
| Managed vector DB | Pinecone | `15-rag/pinecone/` |
|
||||
| Scalable vector DB | Milvus | `15-rag/milvus/` |
|
||||
|
||||
### Prompt Engineering and Structured Output
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Prompt optimization | DSPy | `16-prompt-engineering/dspy/` |
|
||||
| Structured LLM output | Instructor | `16-prompt-engineering/instructor/` |
|
||||
| Constrained generation | Guidance | `16-prompt-engineering/guidance/` |
|
||||
| Grammar-based generation | Outlines | `16-prompt-engineering/outlines/` |
|
||||
|
||||
### Multimodal
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Vision-language models | CLIP | `18-multimodal/clip/` |
|
||||
| Speech recognition | Whisper | `18-multimodal/whisper/` |
|
||||
| Visual instruction tuning | LLaVA | `18-multimodal/llava/` |
|
||||
| Vision-language (Qwen) | Qwen2-VL | `18-multimodal/qwen2-vl/` |
|
||||
| Vision-language (Mistral) | Pixtral | `18-multimodal/pixtral/` |
|
||||
| Visual understanding | Florence-2 | `18-multimodal/florence-2/` |
|
||||
| Document retrieval | ColPali | `18-multimodal/colpali/` |
|
||||
|
||||
### Observability
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| LLM tracing and debugging | LangSmith | `17-observability/langsmith/` |
|
||||
| LLM observability platform | Phoenix | `17-observability/phoenix/` |
|
||||
|
||||
### Emerging Techniques
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Mixture of Experts training | MoE Training | `19-emerging-techniques/moe-training/` |
|
||||
| Combining trained models | Model Merging | `19-emerging-techniques/model-merging/` |
|
||||
| Extended context windows | Long Context | `19-emerging-techniques/long-context/` |
|
||||
| Faster inference via drafting | Speculative Decoding | `19-emerging-techniques/speculative-decoding/` |
|
||||
| Teacher-student compression | Knowledge Distillation | `19-emerging-techniques/knowledge-distillation/` |
|
||||
| Reducing model size | Model Pruning | `19-emerging-techniques/model-pruning/` |
|
||||
|
||||
### Research Output
|
||||
|
||||
| Task | Skill | Location |
|
||||
|---|---|---|
|
||||
| Generate research ideas | Research Ideation | `21-research-ideation/` |
|
||||
| Write publication-ready paper | ML Paper Writing | `20-ml-paper-writing/` |
|
||||
|
||||
## Common Research Workflows
|
||||
|
||||
### "I need to fine-tune a model and evaluate it"
|
||||
|
||||
1. Pick fine-tuning skill based on needs (Unsloth for speed, Axolotl for flexibility)
|
||||
2. Use lm-evaluation-harness for standard benchmarks
|
||||
3. Track with W&B or MLflow
|
||||
|
||||
### "I need to understand what the model learned"
|
||||
|
||||
1. Use TransformerLens for circuit-level analysis
|
||||
2. Train SAEs with SAELens for feature-level understanding
|
||||
3. Run interventions with NNsight or Pyvene
|
||||
|
||||
### "I need to do RL training"
|
||||
|
||||
1. Start with TRL for standard PPO/DPO
|
||||
2. Use GRPO skill for DeepSeek-R1 style training
|
||||
3. Scale with OpenRLHF if needed
|
||||
|
||||
### "I need to run experiments on cloud GPUs"
|
||||
|
||||
1. Modal for quick serverless runs
|
||||
2. SkyPilot for multi-cloud optimization
|
||||
3. Lambda Labs for dedicated instances
|
||||
|
||||
## Finding Skills
|
||||
|
||||
If you're not sure which skill to use:
|
||||
|
||||
```bash
|
||||
# Search by keyword in skill names
|
||||
ls */*/SKILL.md | head -20
|
||||
|
||||
# Search skill descriptions for a keyword
|
||||
grep -l "keyword" */*/SKILL.md
|
||||
```
|
||||
|
||||
Or search the repository's README.md which lists all skills with descriptions.
|
||||
@@ -0,0 +1,43 @@
|
||||
# Research Findings
|
||||
|
||||
## Research Question
|
||||
|
||||
<!-- What are we trying to discover? One clear sentence. -->
|
||||
|
||||
## Current Understanding
|
||||
|
||||
<!-- Updated after each outer loop cycle. What do we know so far?
|
||||
What patterns explain our results? What's the mechanism?
|
||||
This section should read like the core argument of a paper. -->
|
||||
|
||||
## Key Results
|
||||
|
||||
<!-- Significant experimental findings. Include metrics, comparisons, and
|
||||
brief interpretation. Link to experiment directories for full details. -->
|
||||
|
||||
## Patterns and Insights
|
||||
|
||||
<!-- What emerges across multiple experiments? What types of changes
|
||||
consistently work or fail? Why? -->
|
||||
|
||||
## Lessons and Constraints
|
||||
|
||||
<!-- Specific actionable learnings that should guide future experiments.
|
||||
Things you tried that didn't work and WHY, so you don't repeat them.
|
||||
Constraints you discovered about the problem space.
|
||||
|
||||
Examples:
|
||||
- Weight decay > 0.1 causes training instability at 125M param scale
|
||||
- SwiGLU and RoPE improvements stack because they're orthogonal (FFN vs positional)
|
||||
- Baseline only reproduces published numbers with batch_size=64, not 32
|
||||
- Sleep phases before memorization completion hurt — model needs memories to consolidate -->
|
||||
|
||||
## Open Questions
|
||||
|
||||
<!-- What remains unanswered? What would strengthen or challenge
|
||||
our current understanding? -->
|
||||
|
||||
## Optimization Trajectory
|
||||
|
||||
<!-- Summary of inner loop progress. How has the metric evolved?
|
||||
Note inflection points and what caused them. -->
|
||||
@@ -0,0 +1,306 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Research Progress</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, sans-serif;
|
||||
background: #0d1117;
|
||||
color: #e6edf3;
|
||||
line-height: 1.6;
|
||||
padding: 2rem;
|
||||
max-width: 1100px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
padding: 3rem 0 2rem;
|
||||
border-bottom: 1px solid #21262d;
|
||||
margin-bottom: 2.5rem;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
font-size: 2.2rem;
|
||||
font-weight: 700;
|
||||
color: #f0f6fc;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
font-size: 1.15rem;
|
||||
color: #8b949e;
|
||||
font-style: italic;
|
||||
max-width: 700px;
|
||||
margin: 0 auto 1rem;
|
||||
}
|
||||
|
||||
.meta {
|
||||
font-size: 0.85rem;
|
||||
color: #484f58;
|
||||
}
|
||||
|
||||
.meta span {
|
||||
display: inline-block;
|
||||
margin: 0 0.5rem;
|
||||
padding: 0.15rem 0.6rem;
|
||||
background: #161b22;
|
||||
border: 1px solid #21262d;
|
||||
border-radius: 12px;
|
||||
}
|
||||
|
||||
section {
|
||||
margin-bottom: 3rem;
|
||||
}
|
||||
|
||||
section h2 {
|
||||
font-size: 1.4rem;
|
||||
font-weight: 600;
|
||||
color: #f0f6fc;
|
||||
margin-bottom: 1rem;
|
||||
padding-bottom: 0.5rem;
|
||||
border-bottom: 1px solid #21262d;
|
||||
}
|
||||
|
||||
p, li { color: #c9d1d9; }
|
||||
|
||||
.card {
|
||||
background: #161b22;
|
||||
border: 1px solid #21262d;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.card h3 {
|
||||
font-size: 1.05rem;
|
||||
color: #58a6ff;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.result-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
gap: 1rem;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.stat-card {
|
||||
background: #161b22;
|
||||
border: 1px solid #21262d;
|
||||
border-radius: 8px;
|
||||
padding: 1.2rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.stat-card .value {
|
||||
font-size: 2rem;
|
||||
font-weight: 700;
|
||||
color: #58a6ff;
|
||||
}
|
||||
|
||||
.stat-card .label {
|
||||
font-size: 0.8rem;
|
||||
color: #8b949e;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
|
||||
.stat-card.positive .value { color: #3fb950; }
|
||||
.stat-card.negative .value { color: #f85149; }
|
||||
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 0.6rem 1rem;
|
||||
background: #161b22;
|
||||
color: #8b949e;
|
||||
font-size: 0.8rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
border-bottom: 1px solid #21262d;
|
||||
}
|
||||
|
||||
td {
|
||||
padding: 0.6rem 1rem;
|
||||
border-bottom: 1px solid #21262d;
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.badge {
|
||||
display: inline-block;
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 10px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.badge-supported { background: #0d2818; color: #3fb950; border: 1px solid #1b4332; }
|
||||
.badge-refuted { background: #2d1215; color: #f85149; border: 1px solid #4a1c20; }
|
||||
.badge-active { background: #0c2d6b; color: #58a6ff; border: 1px solid #1158c7; }
|
||||
.badge-pending { background: #1c1c1c; color: #8b949e; border: 1px solid #333; }
|
||||
|
||||
.chart-container {
|
||||
background: #161b22;
|
||||
border: 1px solid #21262d;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
text-align: center;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
.next-steps {
|
||||
background: #0c2d6b22;
|
||||
border: 1px solid #1158c744;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
}
|
||||
|
||||
.next-steps h3 { color: #58a6ff; margin-bottom: 0.5rem; }
|
||||
.next-steps ul { padding-left: 1.5rem; }
|
||||
.next-steps li { margin-bottom: 0.3rem; }
|
||||
|
||||
footer {
|
||||
text-align: center;
|
||||
padding: 2rem 0;
|
||||
color: #484f58;
|
||||
font-size: 0.8rem;
|
||||
border-top: 1px solid #21262d;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<!--
|
||||
AGENT INSTRUCTIONS:
|
||||
This is a starting point. Fill in, rearrange, add, or remove sections
|
||||
based on what's compelling from your current research. The goal is a
|
||||
research story, not a status dashboard.
|
||||
|
||||
Replace {{PLACEHOLDERS}} with actual content.
|
||||
Embed SVG charts inline (see progress-reporting.md for the trajectory plot function).
|
||||
Add additional sections as needed.
|
||||
-->
|
||||
|
||||
<header>
|
||||
<h1>{{PROJECT_TITLE}}</h1>
|
||||
<p class="subtitle">{{RESEARCH_QUESTION}}</p>
|
||||
<p class="meta">
|
||||
<span>{{DATE}}</span>
|
||||
<span>{{N_EXPERIMENTS}} experiments</span>
|
||||
<span>Status: {{STATUS}}</span>
|
||||
</p>
|
||||
</header>
|
||||
|
||||
<!-- Summary stats -->
|
||||
<section>
|
||||
<div class="result-grid">
|
||||
<div class="stat-card positive">
|
||||
<div class="value">{{BEST_METRIC}}</div>
|
||||
<div class="label">Best Metric</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value">{{BASELINE_METRIC}}</div>
|
||||
<div class="label">Baseline</div>
|
||||
</div>
|
||||
<div class="stat-card positive">
|
||||
<div class="value">{{IMPROVEMENT}}</div>
|
||||
<div class="label">Improvement</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value">{{N_HYPOTHESES}}</div>
|
||||
<div class="label">Hypotheses Tested</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Background and motivation -->
|
||||
<section id="background">
|
||||
<h2>Background & Motivation</h2>
|
||||
<div class="card">
|
||||
<!-- Why does this research matter? What gap are we addressing? -->
|
||||
<p>{{BACKGROUND_TEXT}}</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Optimization trajectory - THE key visual -->
|
||||
<section id="trajectory">
|
||||
<h2>Optimization Trajectory</h2>
|
||||
<div class="chart-container">
|
||||
<!-- Embed SVG chart here. See references/progress-reporting.md
|
||||
for the generate_trajectory_svg() function. -->
|
||||
{{TRAJECTORY_SVG}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Key findings -->
|
||||
<section id="findings">
|
||||
<h2>Key Findings</h2>
|
||||
<!-- Add cards for each significant finding -->
|
||||
<div class="card">
|
||||
<h3>{{FINDING_1_TITLE}}</h3>
|
||||
<p>{{FINDING_1_DESCRIPTION}}</p>
|
||||
<!-- Include inline plots, tables, or metrics as needed -->
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- What was tried -->
|
||||
<section id="experiments">
|
||||
<h2>What We Tried</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hypothesis</th>
|
||||
<th>Change</th>
|
||||
<th>Result</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<!-- Add rows for notable experiments -->
|
||||
<tr>
|
||||
<td>{{H_ID}}</td>
|
||||
<td>{{CHANGE_SUMMARY}}</td>
|
||||
<td>{{METRIC_DELTA}}</td>
|
||||
<td><span class="badge badge-supported">{{STATUS}}</span></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
|
||||
<!-- Current understanding -->
|
||||
<section id="understanding">
|
||||
<h2>Current Understanding</h2>
|
||||
<div class="card">
|
||||
<!-- The narrative from findings.md, but presented compellingly -->
|
||||
<p>{{CURRENT_UNDERSTANDING}}</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Next steps -->
|
||||
<section id="next">
|
||||
<h2>Next Steps</h2>
|
||||
<div class="next-steps">
|
||||
<ul>
|
||||
<li>{{NEXT_STEP_1}}</li>
|
||||
<li>{{NEXT_STEP_2}}</li>
|
||||
<li>{{NEXT_STEP_3}}</li>
|
||||
</ul>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<footer>
|
||||
Generated by Autoresearch | {{DATE}}
|
||||
</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,40 @@
|
||||
# Research Log
|
||||
|
||||
Chronological record of research decisions and actions. Append-only.
|
||||
|
||||
| # | Date | Type | Summary |
|
||||
|---|------|------|---------|
|
||||
| | | | |
|
||||
|
||||
<!-- Entry types:
|
||||
bootstrap — initial scoping, literature search, hypothesis formation
|
||||
inner-loop — experiment run and result
|
||||
outer-loop — synthesis, reflection, direction decision
|
||||
pivot — change in research direction
|
||||
report — progress presentation generated
|
||||
conclude — decision to finalize and write paper
|
||||
|
||||
Example entries:
|
||||
| 1 | 2026-03-15 | bootstrap | Searched Semantic Scholar + arXiv for efficient transformer architectures. Found 8 relevant papers. Gap: no systematic comparison of GLU variants on small models. Formed 3 hypotheses. Baseline: NanoGPT 5-min run, val_loss=4.82. |
|
||||
| 2 | 2026-03-15 | inner-loop | H1 run_001: swapped ReLU for SwiGLU in FFN. 5-min training run. val_loss=4.61 (baseline 4.82, delta -0.21). Kept. |
|
||||
| 3 | 2026-03-15 | inner-loop | H1 run_002: increased FFN hidden dim from 4x to 5.3x to match SwiGLU param count. val_loss=4.58 (-0.03 vs run_001). Marginal — SwiGLU benefit mostly from gating, not extra params. |
|
||||
| 4 | 2026-03-15 | inner-loop | H1 run_003: tried GEGLU instead of SwiGLU. val_loss=4.63. Slightly worse than SwiGLU. SwiGLU wins for this scale. |
|
||||
| 5 | 2026-03-15 | inner-loop | H2 run_004: replaced learned positional embeddings with RoPE. val_loss=4.55 (-0.06 vs SwiGLU baseline). Promising — stacks with SwiGLU. |
|
||||
| 6 | 2026-03-15 | inner-loop | H2 run_005: RoPE + SwiGLU combined. val_loss=4.41 (-0.41 vs original baseline). Best so far. |
|
||||
| 7 | 2026-03-16 | outer-loop | Reviewed 5 runs. Pattern: gating mechanisms (SwiGLU) and rotary embeddings (RoPE) give independent gains that stack. Combined improvement ~9%. But WHY do they stack? Hypothesis: they operate on orthogonal aspects (FFN expressiveness vs positional encoding). Direction: DEEPEN — test if adding RMSNorm also stacks independently. |
|
||||
| 8 | 2026-03-16 | inner-loop | H3 run_006: replaced LayerNorm with RMSNorm. val_loss=4.39 (-0.02). Small gain. Stacks but diminishing returns on normalization. |
|
||||
| 9 | 2026-03-17 | outer-loop | 8 runs complete. Optimization plateau around val_loss=4.38. The easy architectural wins (SwiGLU, RoPE) are captured. Searched literature on training dynamics — found papers on warmup schedules at small scale. Direction: BROADEN — shift from architecture to training recipe. |
|
||||
| 10 | 2026-03-17 | report | Generated progress-001.html with trajectory plot showing 9% improvement from architectural changes. |
|
||||
|
||||
Example entries (discovery-type research — understanding grokking):
|
||||
| 1 | 2026-03-20 | bootstrap | Searched literature on grokking and delayed generalization. Found Nanda et al. progress measures, Grokfast spectral filtering. Gap: no connection to memory consolidation theory from neuroscience. 3 hypotheses formed. |
|
||||
| 2 | 2026-03-20 | inner-loop | H1 run_001: trained modular addition transformer to memorization (100% train acc, 0% test). Steps to memorize: 1200. Baseline established. |
|
||||
| 3 | 2026-03-20 | inner-loop | H1 run_002: continued training with standard weight decay. Grokking at step 48000. Measured progress measure throughout — sharp transition at step 44000. |
|
||||
| 4 | 2026-03-20 | inner-loop | H1 run_003: inserted "sleep phase" at step 20000 (elevated weight decay + oscillatory LR for 500 steps). Grokking now at step 31000. 35% acceleration. |
|
||||
| 5 | 2026-03-20 | inner-loop | H1 run_004: sleep phase at step 10000. Grokking at step 27000. Earlier sleep = earlier grokking. |
|
||||
| 6 | 2026-03-20 | inner-loop | H1 run_005: sleep phase at step 5000 (before full memorization). Grokking at step 38000. Too early hurts — model hadn't memorized enough for consolidation to work. |
|
||||
| 7 | 2026-03-21 | outer-loop | Reviewed 5 runs. Clear pattern: sleep phases accelerate grokking but only AFTER memorization is complete. This matches memory consolidation theory exactly — you need memories formed before consolidation can reorganize them. Searched for neural slow-wave sleep literature. The weight decay + oscillatory LR during sleep phases mimics synaptic downscaling. Direction: DEEPEN — sweep sleep timing relative to memorization completion. |
|
||||
| 8 | 2026-03-21 | inner-loop | H1.1 run_006-010: swept sleep insertion at 80%, 100%, 120%, 150%, 200% of memorization step. Sweet spot at 110-120%. Consistent across 3 seeds. |
|
||||
| 9 | 2026-03-22 | outer-loop | 10 runs complete. The story is clear: neural networks "dream to learn" just like brains — consolidation after encoding, not during. Grokfast achieves similar acceleration through a different mechanism (gradient spectral filtering). Next: compare gradient spectra during our sleep phases vs Grokfast filtering to see if they converge on the same signal. Direction: BROADEN. |
|
||||
| 10 | 2026-03-22 | report | Generated progress-001.html with sleep timing vs grokking step plot. Key visual: sweet spot curve mirrors neuroscience memory consolidation window. |
|
||||
-->
|
||||
@@ -0,0 +1,57 @@
|
||||
# Research State — Central Project Tracking
|
||||
# Copy this template to your project root and fill in as you go.
|
||||
# Updated by the agent after each experiment and reflection.
|
||||
|
||||
project:
|
||||
title: ""
|
||||
question: "" # The core research question
|
||||
status: active # active | paused | concluded
|
||||
started: "" # ISO date
|
||||
domain: "" # e.g., "mechanistic interpretability", "RL training"
|
||||
|
||||
literature:
|
||||
key_papers: []
|
||||
# - id: "liu2025superposition"
|
||||
# title: "Superposition Yields Robust Neural Scaling"
|
||||
# authors: "Liu et al."
|
||||
# year: 2025
|
||||
# relevance: "Proves ETF structure in LM heads"
|
||||
open_problems: [] # Gaps identified from literature
|
||||
evidence_gaps: [] # What's missing in the field
|
||||
|
||||
hypotheses:
|
||||
# List of all hypotheses, active and completed
|
||||
# - id: H1
|
||||
# statement: "Testable claim with clear prediction"
|
||||
# status: pending # pending | active | supported | refuted | inconclusive
|
||||
# motivation: "Why this is worth testing"
|
||||
# parent: null # null for root, parent ID (e.g., H1) for sub-hypotheses
|
||||
# priority: medium # high | medium | low
|
||||
|
||||
experiments:
|
||||
proxy_metric: "" # What we're optimizing and how to compute it
|
||||
baseline_value: null # Starting point
|
||||
best_value: null # Best achieved so far
|
||||
total_runs: 0
|
||||
trajectory: []
|
||||
# - run_id: "run_001"
|
||||
# hypothesis: "H1"
|
||||
# metric_value: null
|
||||
# delta: null # Change from baseline
|
||||
# wall_time_min: null
|
||||
# change_summary: ""
|
||||
# timestamp: ""
|
||||
|
||||
outer_loop:
|
||||
cycle: 0 # How many outer loop reflections so far
|
||||
last_direction: null # deepen | broaden | pivot | conclude
|
||||
last_reflection: "" # Brief summary of last reflection decision
|
||||
|
||||
workspace:
|
||||
# Track key resource locations
|
||||
findings: "findings.md"
|
||||
log: "research-log.md"
|
||||
literature_dir: "literature/"
|
||||
experiments_dir: "experiments/"
|
||||
to_human_dir: "to_human/"
|
||||
paper_dir: "paper/"
|
||||
@@ -0,0 +1,5 @@
|
||||
# Skills Coming Soon
|
||||
|
||||
This directory will contain high-quality AI research skills for model architecture.
|
||||
|
||||
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
|
||||
@@ -0,0 +1,469 @@
|
||||
---
|
||||
name: implementing-llms-litgpt
|
||||
description: Implements and trains LLMs using Lightning AI's LitGPT with 20+ pretrained architectures (Llama, Gemma, Phi, Qwen, Mistral). Use when need clean model implementations, educational understanding of architectures, or production fine-tuning with LoRA/QLoRA. Single-file implementations, no abstraction layers.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Model Architecture, LitGPT, Lightning AI, LLM Implementation, LoRA, QLoRA, Fine-Tuning, Llama, Gemma, Phi, Mistral, Educational]
|
||||
dependencies: [litgpt, torch, transformers]
|
||||
---
|
||||
|
||||
# LitGPT - Clean LLM Implementations
|
||||
|
||||
## Quick start
|
||||
|
||||
LitGPT provides 20+ pretrained LLM implementations with clean, readable code and production-ready training workflows.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install 'litgpt[extra]'
|
||||
```
|
||||
|
||||
**Load and use any model**:
|
||||
```python
|
||||
from litgpt import LLM
|
||||
|
||||
# Load pretrained model
|
||||
llm = LLM.load("microsoft/phi-2")
|
||||
|
||||
# Generate text
|
||||
result = llm.generate(
|
||||
"What is the capital of France?",
|
||||
max_new_tokens=50,
|
||||
temperature=0.7
|
||||
)
|
||||
print(result)
|
||||
```
|
||||
|
||||
**List available models**:
|
||||
```bash
|
||||
litgpt download list
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Fine-tune on custom dataset
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Fine-Tuning Setup:
|
||||
- [ ] Step 1: Download pretrained model
|
||||
- [ ] Step 2: Prepare dataset
|
||||
- [ ] Step 3: Configure training
|
||||
- [ ] Step 4: Run fine-tuning
|
||||
```
|
||||
|
||||
**Step 1: Download pretrained model**
|
||||
|
||||
```bash
|
||||
# Download Llama 3 8B
|
||||
litgpt download meta-llama/Meta-Llama-3-8B
|
||||
|
||||
# Download Phi-2 (smaller, faster)
|
||||
litgpt download microsoft/phi-2
|
||||
|
||||
# Download Gemma 2B
|
||||
litgpt download google/gemma-2b
|
||||
```
|
||||
|
||||
Models are saved to `checkpoints/` directory.
|
||||
|
||||
**Step 2: Prepare dataset**
|
||||
|
||||
LitGPT supports multiple formats:
|
||||
|
||||
**Alpaca format** (instruction-response):
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "What is the capital of France?",
|
||||
"input": "",
|
||||
"output": "The capital of France is Paris."
|
||||
},
|
||||
{
|
||||
"instruction": "Translate to Spanish: Hello, how are you?",
|
||||
"input": "",
|
||||
"output": "Hola, ¿cómo estás?"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Save as `data/my_dataset.json`.
|
||||
|
||||
**Step 3: Configure training**
|
||||
|
||||
```bash
|
||||
# Full fine-tuning (requires 40GB+ GPU for 7B models)
|
||||
litgpt finetune \
|
||||
meta-llama/Meta-Llama-3-8B \
|
||||
--data JSON \
|
||||
--data.json_path data/my_dataset.json \
|
||||
--train.max_steps 1000 \
|
||||
--train.learning_rate 2e-5 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.global_batch_size 16
|
||||
|
||||
# LoRA fine-tuning (efficient, 16GB GPU)
|
||||
litgpt finetune_lora \
|
||||
microsoft/phi-2 \
|
||||
--data JSON \
|
||||
--data.json_path data/my_dataset.json \
|
||||
--lora_r 16 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.05 \
|
||||
--train.max_steps 1000 \
|
||||
--train.learning_rate 1e-4
|
||||
```
|
||||
|
||||
**Step 4: Run fine-tuning**
|
||||
|
||||
Training saves checkpoints to `out/finetune/` automatically.
|
||||
|
||||
Monitor training:
|
||||
```bash
|
||||
# View logs
|
||||
tail -f out/finetune/logs.txt
|
||||
|
||||
# TensorBoard (if using --train.logger_name tensorboard)
|
||||
tensorboard --logdir out/finetune/lightning_logs
|
||||
```
|
||||
|
||||
### Workflow 2: LoRA fine-tuning on single GPU
|
||||
|
||||
Most memory-efficient option.
|
||||
|
||||
```
|
||||
LoRA Training:
|
||||
- [ ] Step 1: Choose base model
|
||||
- [ ] Step 2: Configure LoRA parameters
|
||||
- [ ] Step 3: Train with LoRA
|
||||
- [ ] Step 4: Merge LoRA weights (optional)
|
||||
```
|
||||
|
||||
**Step 1: Choose base model**
|
||||
|
||||
For limited GPU memory (12-16GB):
|
||||
- **Phi-2** (2.7B) - Best quality/size tradeoff
|
||||
- **Llama 3 1B** - Smallest, fastest
|
||||
- **Gemma 2B** - Good reasoning
|
||||
|
||||
**Step 2: Configure LoRA parameters**
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora \
|
||||
microsoft/phi-2 \
|
||||
--data JSON \
|
||||
--data.json_path data/my_dataset.json \
|
||||
--lora_r 16 \ # LoRA rank (8-64, higher=more capacity)
|
||||
--lora_alpha 32 \ # LoRA scaling (typically 2×r)
|
||||
--lora_dropout 0.05 \ # Prevent overfitting
|
||||
--lora_query true \ # Apply LoRA to query projection
|
||||
--lora_key false \ # Usually not needed
|
||||
--lora_value true \ # Apply LoRA to value projection
|
||||
--lora_projection true \ # Apply LoRA to output projection
|
||||
--lora_mlp false \ # Usually not needed
|
||||
--lora_head false # Usually not needed
|
||||
```
|
||||
|
||||
LoRA rank guide:
|
||||
- `r=8`: Lightweight, 2-4MB adapters
|
||||
- `r=16`: Standard, good quality
|
||||
- `r=32`: High capacity, use for complex tasks
|
||||
- `r=64`: Maximum quality, 4× larger adapters
|
||||
|
||||
**Step 3: Train with LoRA**
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora \
|
||||
microsoft/phi-2 \
|
||||
--data JSON \
|
||||
--data.json_path data/my_dataset.json \
|
||||
--lora_r 16 \
|
||||
--train.epochs 3 \
|
||||
--train.learning_rate 1e-4 \
|
||||
--train.micro_batch_size 4 \
|
||||
--train.global_batch_size 32 \
|
||||
--out_dir out/phi2-lora
|
||||
|
||||
# Memory usage: ~8-12GB for Phi-2 with LoRA
|
||||
```
|
||||
|
||||
**Step 4: Merge LoRA weights** (optional)
|
||||
|
||||
Merge LoRA adapters into base model for deployment:
|
||||
|
||||
```bash
|
||||
litgpt merge_lora \
|
||||
out/phi2-lora/final \
|
||||
--out_dir out/phi2-merged
|
||||
```
|
||||
|
||||
Now use merged model:
|
||||
```python
|
||||
from litgpt import LLM
|
||||
llm = LLM.load("out/phi2-merged")
|
||||
```
|
||||
|
||||
### Workflow 3: Pretrain from scratch
|
||||
|
||||
Train new model on your domain data.
|
||||
|
||||
```
|
||||
Pretraining:
|
||||
- [ ] Step 1: Prepare pretraining dataset
|
||||
- [ ] Step 2: Configure model architecture
|
||||
- [ ] Step 3: Set up multi-GPU training
|
||||
- [ ] Step 4: Launch pretraining
|
||||
```
|
||||
|
||||
**Step 1: Prepare pretraining dataset**
|
||||
|
||||
LitGPT expects tokenized data. Use `prepare_dataset.py`:
|
||||
|
||||
```bash
|
||||
python scripts/prepare_dataset.py \
|
||||
--source_path data/my_corpus.txt \
|
||||
--checkpoint_dir checkpoints/tokenizer \
|
||||
--destination_path data/pretrain \
|
||||
--split train,val
|
||||
```
|
||||
|
||||
**Step 2: Configure model architecture**
|
||||
|
||||
Edit config file or use existing:
|
||||
|
||||
```python
|
||||
# config/pythia-160m.yaml
|
||||
model_name: pythia-160m
|
||||
block_size: 2048
|
||||
vocab_size: 50304
|
||||
n_layer: 12
|
||||
n_head: 12
|
||||
n_embd: 768
|
||||
rotary_percentage: 0.25
|
||||
parallel_residual: true
|
||||
bias: true
|
||||
```
|
||||
|
||||
**Step 3: Set up multi-GPU training**
|
||||
|
||||
```bash
|
||||
# Single GPU
|
||||
litgpt pretrain \
|
||||
--config config/pythia-160m.yaml \
|
||||
--data.data_dir data/pretrain \
|
||||
--train.max_tokens 10_000_000_000
|
||||
|
||||
# Multi-GPU with FSDP
|
||||
litgpt pretrain \
|
||||
--config config/pythia-1b.yaml \
|
||||
--data.data_dir data/pretrain \
|
||||
--devices 8 \
|
||||
--train.max_tokens 100_000_000_000
|
||||
```
|
||||
|
||||
**Step 4: Launch pretraining**
|
||||
|
||||
For large-scale pretraining on cluster:
|
||||
|
||||
```bash
|
||||
# Using SLURM
|
||||
sbatch --nodes=8 --gpus-per-node=8 \
|
||||
pretrain_script.sh
|
||||
|
||||
# pretrain_script.sh content:
|
||||
litgpt pretrain \
|
||||
--config config/pythia-1b.yaml \
|
||||
--data.data_dir /shared/data/pretrain \
|
||||
--devices 8 \
|
||||
--num_nodes 8 \
|
||||
--train.global_batch_size 512 \
|
||||
--train.max_tokens 300_000_000_000
|
||||
```
|
||||
|
||||
### Workflow 4: Convert and deploy model
|
||||
|
||||
Export LitGPT models for production.
|
||||
|
||||
```
|
||||
Model Deployment:
|
||||
- [ ] Step 1: Test inference locally
|
||||
- [ ] Step 2: Quantize model (optional)
|
||||
- [ ] Step 3: Convert to GGUF (for llama.cpp)
|
||||
- [ ] Step 4: Deploy with API
|
||||
```
|
||||
|
||||
**Step 1: Test inference locally**
|
||||
|
||||
```python
|
||||
from litgpt import LLM
|
||||
|
||||
llm = LLM.load("out/phi2-lora/final")
|
||||
|
||||
# Single generation
|
||||
print(llm.generate("What is machine learning?"))
|
||||
|
||||
# Streaming
|
||||
for token in llm.generate("Explain quantum computing", stream=True):
|
||||
print(token, end="", flush=True)
|
||||
|
||||
# Batch inference
|
||||
prompts = ["Hello", "Goodbye", "Thank you"]
|
||||
results = [llm.generate(p) for p in prompts]
|
||||
```
|
||||
|
||||
**Step 2: Quantize model** (optional)
|
||||
|
||||
Reduce model size with minimal quality loss:
|
||||
|
||||
```bash
|
||||
# 8-bit quantization (50% size reduction)
|
||||
litgpt convert_lit_checkpoint \
|
||||
out/phi2-lora/final \
|
||||
--dtype bfloat16 \
|
||||
--quantize bnb.nf4
|
||||
|
||||
# 4-bit quantization (75% size reduction)
|
||||
litgpt convert_lit_checkpoint \
|
||||
out/phi2-lora/final \
|
||||
--quantize bnb.nf4-dq # Double quantization
|
||||
```
|
||||
|
||||
**Step 3: Convert to GGUF** (for llama.cpp)
|
||||
|
||||
```bash
|
||||
python scripts/convert_lit_checkpoint.py \
|
||||
--checkpoint_path out/phi2-lora/final \
|
||||
--output_path models/phi2.gguf \
|
||||
--model_name microsoft/phi-2
|
||||
```
|
||||
|
||||
**Step 4: Deploy with API**
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from litgpt import LLM
|
||||
|
||||
app = FastAPI()
|
||||
llm = LLM.load("out/phi2-lora/final")
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt: str, max_tokens: int = 100):
|
||||
result = llm.generate(
|
||||
prompt,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.7
|
||||
)
|
||||
return {"response": result}
|
||||
|
||||
# Run: uvicorn api:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use LitGPT when:**
|
||||
- Want to understand LLM architectures (clean, readable code)
|
||||
- Need production-ready training recipes
|
||||
- Educational purposes or research
|
||||
- Prototyping new model ideas
|
||||
- Lightning ecosystem user
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Axolotl/TRL**: More fine-tuning features, YAML configs
|
||||
- **Megatron-Core**: Maximum performance for >70B models
|
||||
- **HuggingFace Transformers**: Broadest model support
|
||||
- **vLLM**: Inference-only (no training)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory during fine-tuning**
|
||||
|
||||
Use LoRA instead of full fine-tuning:
|
||||
```bash
|
||||
# Instead of litgpt finetune (requires 40GB+)
|
||||
litgpt finetune_lora # Only needs 12-16GB
|
||||
```
|
||||
|
||||
Or enable gradient checkpointing:
|
||||
```bash
|
||||
litgpt finetune_lora \
|
||||
... \
|
||||
--train.gradient_accumulation_iters 4 # Accumulate gradients
|
||||
```
|
||||
|
||||
**Issue: Training too slow**
|
||||
|
||||
Enable Flash Attention (built-in, automatic on compatible hardware):
|
||||
```python
|
||||
# Already enabled by default on Ampere+ GPUs (A100, RTX 30/40 series)
|
||||
# No configuration needed
|
||||
```
|
||||
|
||||
Use smaller micro-batch and accumulate:
|
||||
```bash
|
||||
--train.micro_batch_size 1 \
|
||||
--train.global_batch_size 32 \
|
||||
--train.gradient_accumulation_iters 32 # Effective batch=32
|
||||
```
|
||||
|
||||
**Issue: Model not loading**
|
||||
|
||||
Check model name:
|
||||
```bash
|
||||
# List all available models
|
||||
litgpt download list
|
||||
|
||||
# Download if not exists
|
||||
litgpt download meta-llama/Meta-Llama-3-8B
|
||||
```
|
||||
|
||||
Verify checkpoints directory:
|
||||
```bash
|
||||
ls checkpoints/
|
||||
# Should see: meta-llama/Meta-Llama-3-8B/
|
||||
```
|
||||
|
||||
**Issue: LoRA adapters too large**
|
||||
|
||||
Reduce LoRA rank:
|
||||
```bash
|
||||
--lora_r 8 # Instead of 16 or 32
|
||||
```
|
||||
|
||||
Apply LoRA to fewer layers:
|
||||
```bash
|
||||
--lora_query true \
|
||||
--lora_value true \
|
||||
--lora_projection false \ # Disable this
|
||||
--lora_mlp false # And this
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Supported architectures**: See [references/supported-models.md](references/supported-models.md) for complete list of 20+ model families with sizes and capabilities.
|
||||
|
||||
**Training recipes**: See [references/training-recipes.md](references/training-recipes.md) for proven hyperparameter configurations for pretraining and fine-tuning.
|
||||
|
||||
**FSDP configuration**: See [references/distributed-training.md](references/distributed-training.md) for multi-GPU training with Fully Sharded Data Parallel.
|
||||
|
||||
**Custom architectures**: See [references/custom-models.md](references/custom-models.md) for implementing new model architectures in LitGPT style.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA 11.8+), AMD (ROCm), Apple Silicon (MPS)
|
||||
- **Memory**:
|
||||
- Inference (Phi-2): 6GB
|
||||
- LoRA fine-tuning (7B): 16GB
|
||||
- Full fine-tuning (7B): 40GB+
|
||||
- Pretraining (1B): 24GB
|
||||
- **Storage**: 5-50GB per model (depending on size)
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/Lightning-AI/litgpt
|
||||
- Docs: https://lightning.ai/docs/litgpt
|
||||
- Tutorials: https://lightning.ai/docs/litgpt/tutorials
|
||||
- Model zoo: 20+ pretrained architectures (Llama, Gemma, Phi, Qwen, Mistral, Mixtral, Falcon, etc.)
|
||||
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
# Custom Models
|
||||
|
||||
Guide to implementing custom model architectures in LitGPT.
|
||||
|
||||
## Overview
|
||||
|
||||
LitGPT's clean, single-file implementations make it easy to create custom architectures. You can extend the base `GPT` class or create entirely new models.
|
||||
|
||||
**Use cases**:
|
||||
- Implementing new research architectures
|
||||
- Adapting models for specific domains
|
||||
- Experimenting with attention mechanisms
|
||||
- Adding custom layers or components
|
||||
|
||||
## Key Files and Classes
|
||||
|
||||
### Core Architecture (`litgpt/model.py`)
|
||||
|
||||
**Main classes**:
|
||||
- `GPT`: Top-level model class
|
||||
- `Block`: Transformer block (attention + MLP)
|
||||
- `CausalSelfAttention`: Attention mechanism
|
||||
- `MLP`: Feed-forward network
|
||||
- `RMSNorm` / `LayerNorm`: Normalization layers
|
||||
|
||||
**Configuration** (`litgpt/config.py`):
|
||||
- `Config`: Base configuration dataclass
|
||||
- Model-specific configs: `LlamaConfig`, `MistralConfig`, `PhiConfig`, etc.
|
||||
|
||||
## Custom Architecture Workflow
|
||||
|
||||
### Step 1: Define Configuration
|
||||
|
||||
Create a `Config` dataclass with your model's hyperparameters:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from litgpt.config import Config
|
||||
|
||||
@dataclass
|
||||
class MyModelConfig(Config):
|
||||
"""Configuration for my custom model."""
|
||||
# Standard parameters
|
||||
name: str = "my-model-7b"
|
||||
block_size: int = 4096
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
n_embd: int = 4096
|
||||
|
||||
# Custom parameters
|
||||
custom_param: float = 0.1
|
||||
use_custom_attention: bool = True
|
||||
|
||||
# Optional: override defaults
|
||||
rope_base: int = 10000
|
||||
intermediate_size: int = 11008
|
||||
```
|
||||
|
||||
### Step 2: Implement Custom Components
|
||||
|
||||
#### Option A: Custom Attention
|
||||
|
||||
```python
|
||||
from litgpt.model import CausalSelfAttention
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class CustomAttention(CausalSelfAttention):
|
||||
"""Custom attention mechanism."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Add custom components
|
||||
self.custom_proj = nn.Linear(config.n_embd, config.n_embd)
|
||||
self.custom_param = config.custom_param
|
||||
|
||||
def forward(self, x, mask=None, input_pos=None):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Standard Q, K, V projections
|
||||
q = self.attn(x)
|
||||
k = self.attn(x)
|
||||
v = self.attn(x)
|
||||
|
||||
# Custom modification
|
||||
q = q + self.custom_proj(x) * self.custom_param
|
||||
|
||||
# Rest of attention computation
|
||||
q = q.view(B, T, self.n_head, self.head_size)
|
||||
k = k.view(B, T, self.n_query_groups, self.head_size)
|
||||
v = v.view(B, T, self.n_query_groups, self.head_size)
|
||||
|
||||
# Scaled dot-product attention
|
||||
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
|
||||
|
||||
y = y.reshape(B, T, C)
|
||||
return self.proj(y)
|
||||
```
|
||||
|
||||
#### Option B: Custom MLP
|
||||
|
||||
```python
|
||||
from litgpt.model import MLP
|
||||
|
||||
class CustomMLP(MLP):
|
||||
"""Custom feed-forward network."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Add custom layers
|
||||
self.custom_layer = nn.Linear(config.intermediate_size, config.intermediate_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc_1(x)
|
||||
x = self.act(x)
|
||||
x = self.custom_layer(x) # Custom modification
|
||||
x = self.fc_2(x)
|
||||
return x
|
||||
```
|
||||
|
||||
#### Option C: Custom Block
|
||||
|
||||
```python
|
||||
from litgpt.model import Block
|
||||
|
||||
class CustomBlock(Block):
|
||||
"""Custom transformer block."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Replace attention or MLP
|
||||
self.attn = CustomAttention(config)
|
||||
# Or: self.mlp = CustomMLP(config)
|
||||
|
||||
# Add custom components
|
||||
self.custom_norm = nn.LayerNorm(config.n_embd)
|
||||
|
||||
def forward(self, x, input_pos=None, mask=None):
|
||||
# Custom forward pass
|
||||
h = self.norm_1(x)
|
||||
h = self.attn(h, mask=mask, input_pos=input_pos)
|
||||
x = x + h
|
||||
|
||||
# Custom normalization
|
||||
x = x + self.custom_norm(x)
|
||||
|
||||
x = x + self.mlp(self.norm_2(x))
|
||||
return x
|
||||
```
|
||||
|
||||
### Step 3: Create Custom GPT Model
|
||||
|
||||
```python
|
||||
from litgpt.model import GPT
|
||||
import torch.nn as nn
|
||||
|
||||
class CustomGPT(GPT):
|
||||
"""Custom GPT model."""
|
||||
|
||||
def __init__(self, config: MyModelConfig):
|
||||
# Don't call super().__init__() - we reimplement
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
|
||||
# Standard components
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.transformer = nn.ModuleDict(
|
||||
dict(
|
||||
wte=nn.Embedding(config.vocab_size, config.n_embd),
|
||||
h=nn.ModuleList(CustomBlock(config) for _ in range(config.n_layer)),
|
||||
ln_f=nn.LayerNorm(config.n_embd),
|
||||
)
|
||||
)
|
||||
|
||||
# Custom components
|
||||
if config.use_custom_attention:
|
||||
self.custom_embedding = nn.Linear(config.n_embd, config.n_embd)
|
||||
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize weights (required)."""
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, idx, input_pos=None):
|
||||
"""Forward pass (must match base signature)."""
|
||||
B, T = idx.size()
|
||||
|
||||
# Token embeddings
|
||||
x = self.transformer.wte(idx)
|
||||
|
||||
# Custom embedding modification
|
||||
if self.config.use_custom_attention:
|
||||
x = x + self.custom_embedding(x)
|
||||
|
||||
# Transformer blocks
|
||||
for block in self.transformer.h:
|
||||
x = block(x, input_pos=input_pos)
|
||||
|
||||
# Final norm + LM head
|
||||
x = self.transformer.ln_f(x)
|
||||
return self.lm_head(x)
|
||||
```
|
||||
|
||||
### Step 4: Register Configuration
|
||||
|
||||
Add your config to `litgpt/config.py`:
|
||||
|
||||
```python
|
||||
# In litgpt/config.py
|
||||
configs = [
|
||||
# ... existing configs ...
|
||||
|
||||
# My custom model
|
||||
dict(
|
||||
name="my-model-7b",
|
||||
hf_config=dict(org="myorg", name="my-model-7b"),
|
||||
block_size=4096,
|
||||
vocab_size=32000,
|
||||
n_layer=32,
|
||||
n_head=32,
|
||||
n_embd=4096,
|
||||
custom_param=0.1,
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
### Step 5: Use Your Custom Model
|
||||
|
||||
```python
|
||||
from litgpt.api import LLM
|
||||
from my_model import CustomGPT, MyModelConfig
|
||||
|
||||
# Initialize
|
||||
config = MyModelConfig()
|
||||
model = CustomGPT(config)
|
||||
|
||||
# Wrap with LLM API
|
||||
llm = LLM(model=model, tokenizer_dir="path/to/tokenizer")
|
||||
|
||||
# Generate
|
||||
result = llm.generate("Once upon a time", max_new_tokens=100)
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Real Example: Adapter Fine-tuning
|
||||
|
||||
LitGPT's `Adapter` implementation shows a complete custom architecture:
|
||||
|
||||
### Adapter Configuration
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class Config(BaseConfig):
|
||||
"""Adds adapter-specific parameters."""
|
||||
adapter_prompt_length: int = 10
|
||||
adapter_start_layer: int = 2
|
||||
```
|
||||
|
||||
### Adapter GPT Model
|
||||
|
||||
```python
|
||||
class GPT(BaseModel):
|
||||
"""GPT model with adapter layers."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
|
||||
# Standard components
|
||||
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
|
||||
self.transformer = nn.ModuleDict(
|
||||
dict(
|
||||
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
||||
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
|
||||
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
||||
)
|
||||
)
|
||||
|
||||
# Adapter-specific: gating factor
|
||||
self.gating_factor = torch.nn.Parameter(torch.zeros(1))
|
||||
```
|
||||
|
||||
### Adapter Block
|
||||
|
||||
```python
|
||||
class Block(BaseBlock):
|
||||
"""Transformer block with adapter."""
|
||||
|
||||
def __init__(self, config: Config, block_idx: int):
|
||||
super().__init__()
|
||||
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
||||
self.attn = CausalSelfAttention(config, block_idx)
|
||||
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
||||
self.mlp = config.mlp_class(config)
|
||||
|
||||
# Adapter: add prefix for certain layers
|
||||
self.adapter_wte = (
|
||||
nn.Embedding(config.adapter_prompt_length, config.n_embd)
|
||||
if block_idx >= config.adapter_start_layer
|
||||
else None
|
||||
)
|
||||
```
|
||||
|
||||
### Adapter Attention
|
||||
|
||||
```python
|
||||
class CausalSelfAttention(BaseCausalSelfAttention):
|
||||
"""Attention with adapter prompts."""
|
||||
|
||||
def forward(self, x: torch.Tensor, ...) -> torch.Tensor:
|
||||
B, T, C = x.size()
|
||||
|
||||
# Add adapter prefix if enabled
|
||||
if self.adapter_wte is not None:
|
||||
adapter_prompts = self.adapter_wte(
|
||||
torch.arange(self.adapter_prompt_length, device=x.device)
|
||||
)
|
||||
adapter_prompts = adapter_prompts.unsqueeze(0).expand(B, -1, -1)
|
||||
x = torch.cat([adapter_prompts, x], dim=1)
|
||||
|
||||
# Standard attention with gating
|
||||
q, k, v = self.attn(x).split(self.n_embd, dim=2)
|
||||
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
|
||||
|
||||
# Apply gating factor
|
||||
y = y * self.gating_factor
|
||||
|
||||
return self.proj(y)
|
||||
```
|
||||
|
||||
See full implementation: `litgpt/finetune/adapter.py`
|
||||
|
||||
## Real Example: AdapterV2
|
||||
|
||||
AdapterV2 shows custom linear layers:
|
||||
|
||||
### AdapterV2Linear
|
||||
|
||||
```python
|
||||
class AdapterV2Linear(torch.nn.Module):
|
||||
"""Linear layer with low-rank adapter."""
|
||||
|
||||
def __init__(self, in_features, out_features, adapter_rank=8, **kwargs):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
|
||||
|
||||
# Adapter: low-rank bottleneck
|
||||
self.adapter_down = torch.nn.Linear(in_features, adapter_rank, bias=False)
|
||||
self.adapter_up = torch.nn.Linear(adapter_rank, out_features, bias=False)
|
||||
|
||||
# Initialize adapter to identity
|
||||
torch.nn.init.zeros_(self.adapter_up.weight)
|
||||
|
||||
def forward(self, x):
|
||||
# Original linear transformation
|
||||
out = self.linear(x)
|
||||
|
||||
# Add adapter contribution
|
||||
adapter_out = self.adapter_up(self.adapter_down(x))
|
||||
return out + adapter_out
|
||||
```
|
||||
|
||||
See full implementation: `litgpt/finetune/adapter_v2.py`
|
||||
|
||||
## Custom Model Checklist
|
||||
|
||||
- [ ] Define `Config` dataclass with all hyperparameters
|
||||
- [ ] Implement custom components (Attention, MLP, Block)
|
||||
- [ ] Create custom `GPT` class
|
||||
- [ ] Implement `_init_weights()` for proper initialization
|
||||
- [ ] Implement `forward()` matching base signature
|
||||
- [ ] Register configuration in `litgpt/config.py`
|
||||
- [ ] Test with small model (100M params) first
|
||||
- [ ] Verify training convergence
|
||||
- [ ] Profile memory usage
|
||||
|
||||
## Testing Your Custom Model
|
||||
|
||||
### Unit Test
|
||||
|
||||
```python
|
||||
import torch
|
||||
from my_model import CustomGPT, MyModelConfig
|
||||
|
||||
def test_custom_model():
|
||||
"""Test custom model forward pass."""
|
||||
config = MyModelConfig(
|
||||
n_layer=2,
|
||||
n_head=4,
|
||||
n_embd=128,
|
||||
vocab_size=1000,
|
||||
block_size=256,
|
||||
)
|
||||
|
||||
model = CustomGPT(config)
|
||||
model.eval()
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 2
|
||||
seq_length = 16
|
||||
idx = torch.randint(0, config.vocab_size, (batch_size, seq_length))
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(idx)
|
||||
|
||||
assert logits.shape == (batch_size, seq_length, config.vocab_size)
|
||||
print("✓ Forward pass works")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_custom_model()
|
||||
```
|
||||
|
||||
### Training Test
|
||||
|
||||
```python
|
||||
from litgpt.api import LLM
|
||||
|
||||
def test_training():
|
||||
"""Test custom model training."""
|
||||
config = MyModelConfig(n_layer=2, n_head=4, n_embd=128)
|
||||
model = CustomGPT(config)
|
||||
|
||||
# Small dataset for testing
|
||||
data = [
|
||||
{"instruction": "Test", "input": "", "output": "OK"}
|
||||
]
|
||||
|
||||
# Should run without errors
|
||||
llm = LLM(model=model)
|
||||
# ... training code ...
|
||||
print("✓ Training works")
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Adding New Attention Mechanism
|
||||
|
||||
```python
|
||||
class MyAttention(nn.Module):
|
||||
"""Template for custom attention."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.n_head = config.n_head
|
||||
self.n_embd = config.n_embd
|
||||
self.head_size = self.n_embd // self.n_head
|
||||
|
||||
# Q, K, V projections
|
||||
self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
||||
self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
||||
self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
# Output projection
|
||||
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project Q, K, V
|
||||
q = self.q_proj(x).view(B, T, self.n_head, self.head_size)
|
||||
k = self.k_proj(x).view(B, T, self.n_head, self.head_size)
|
||||
v = self.v_proj(x).view(B, T, self.n_head, self.head_size)
|
||||
|
||||
# Custom attention computation here
|
||||
# attn = custom_attention_function(q, k, v, mask)
|
||||
|
||||
# Output projection
|
||||
out = self.out_proj(attn.reshape(B, T, C))
|
||||
return out
|
||||
```
|
||||
|
||||
### Adding Mixture of Experts
|
||||
|
||||
```python
|
||||
class MoELayer(nn.Module):
|
||||
"""Mixture of Experts layer."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.moe_top_k
|
||||
|
||||
# Router
|
||||
self.router = nn.Linear(config.n_embd, self.num_experts)
|
||||
|
||||
# Experts
|
||||
self.experts = nn.ModuleList([
|
||||
MLP(config) for _ in range(self.num_experts)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Route tokens to experts
|
||||
router_logits = self.router(x) # (B, T, num_experts)
|
||||
router_probs = torch.softmax(router_logits, dim=-1)
|
||||
|
||||
# Select top-k experts
|
||||
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
||||
|
||||
# Process through selected experts
|
||||
output = torch.zeros_like(x)
|
||||
for i in range(self.top_k):
|
||||
expert_idx = top_k_indices[:, :, i]
|
||||
expert_prob = top_k_probs[:, :, i:i+1]
|
||||
|
||||
# Route to expert
|
||||
for expert_id in range(self.num_experts):
|
||||
mask = (expert_idx == expert_id)
|
||||
if mask.any():
|
||||
expert_out = self.experts[expert_id](x[mask])
|
||||
output[mask] += expert_out * expert_prob[mask]
|
||||
|
||||
return output
|
||||
```
|
||||
|
||||
### Adding Positional Encoding
|
||||
|
||||
```python
|
||||
class CustomPositionalEncoding(nn.Module):
|
||||
"""Custom positional encoding."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.n_embd = config.n_embd
|
||||
self.register_buffer(
|
||||
"pos_encoding",
|
||||
self._create_encoding(config.block_size, config.n_embd)
|
||||
)
|
||||
|
||||
def _create_encoding(self, max_len, d_model):
|
||||
"""Create positional encoding matrix."""
|
||||
pos = torch.arange(max_len).unsqueeze(1)
|
||||
div = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
|
||||
|
||||
encoding = torch.zeros(max_len, d_model)
|
||||
encoding[:, 0::2] = torch.sin(pos * div)
|
||||
encoding[:, 1::2] = torch.cos(pos * div)
|
||||
return encoding
|
||||
|
||||
def forward(self, x):
|
||||
"""Add positional encoding."""
|
||||
return x + self.pos_encoding[:x.size(1), :]
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
1. **Start small**: Test with 2 layers, 128 hidden size
|
||||
2. **Check shapes**: Print tensor shapes at each step
|
||||
3. **Verify gradients**: Ensure all parameters have gradients
|
||||
4. **Compare to base**: Run same config with base `GPT` model
|
||||
5. **Profile memory**: Use `torch.cuda.memory_summary()`
|
||||
|
||||
## References
|
||||
|
||||
- Base model: `litgpt/model.py`
|
||||
- Configuration: `litgpt/config.py`
|
||||
- Adapter example: `litgpt/finetune/adapter.py`
|
||||
- AdapterV2 example: `litgpt/finetune/adapter_v2.py`
|
||||
- LoRA example: `litgpt/finetune/lora.py`
|
||||
@@ -0,0 +1,451 @@
|
||||
# Distributed Training
|
||||
|
||||
Guide to FSDP (Fully Sharded Data Parallel) distributed training in LitGPT for scaling to multiple GPUs and nodes.
|
||||
|
||||
## Overview
|
||||
|
||||
LitGPT uses **Lightning Fabric** with **FSDP** to distribute training across multiple GPUs. FSDP shards model parameters, gradients, and optimizer states to enable training models larger than single-GPU memory.
|
||||
|
||||
**When to use FSDP**:
|
||||
- Model doesn't fit on single GPU
|
||||
- Want faster training with multi-GPU
|
||||
- Training models >7B parameters
|
||||
- Need to scale across multiple nodes
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Single Node Multi-GPU
|
||||
|
||||
```bash
|
||||
# Train Llama 2 7B on 4 GPUs
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--devices 4 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json
|
||||
```
|
||||
|
||||
FSDP is **automatically enabled** when `devices > 1`.
|
||||
|
||||
### Multi-Node Training
|
||||
|
||||
```bash
|
||||
# Train on 2 nodes with 8 GPUs each (16 total)
|
||||
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
|
||||
--devices 8 \
|
||||
--num_nodes 2 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json
|
||||
```
|
||||
|
||||
## FSDP Configuration
|
||||
|
||||
### Default FSDP Strategy
|
||||
|
||||
When multiple devices are used, LitGPT applies this FSDP configuration:
|
||||
|
||||
```python
|
||||
from lightning.fabric.strategies import FSDPStrategy
|
||||
from litgpt.model import Block
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
auto_wrap_policy={Block},
|
||||
state_dict_type="full",
|
||||
sharding_strategy="HYBRID_SHARD"
|
||||
)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `auto_wrap_policy={Block}`: Automatically wraps each transformer `Block` with FSDP
|
||||
- `state_dict_type="full"`: Saves full model (assembled on rank 0) for easy deployment
|
||||
- `sharding_strategy="HYBRID_SHARD"`: Shards parameters, gradients, and optimizer states
|
||||
|
||||
### Sharding Strategies
|
||||
|
||||
| Strategy | Shards | Communication | Use Case |
|
||||
|----------|--------|---------------|----------|
|
||||
| `FULL_SHARD` (ZeRO-3) | Params + Grads + Optim | All-gather before forward/backward | Maximum memory savings |
|
||||
| `SHARD_GRAD_OP` (ZeRO-2) | Grads + Optim only | Reduce-scatter after backward | Faster than FULL_SHARD |
|
||||
| `HYBRID_SHARD` (default) | All (hybrid across nodes) | Optimized for multi-node | Best for clusters |
|
||||
| `NO_SHARD` | None | Broadcast | Single GPU (no FSDP) |
|
||||
|
||||
**Recommendation**: Use default `HYBRID_SHARD` for multi-node, or `FULL_SHARD` for single-node multi-GPU.
|
||||
|
||||
### State Dict Types
|
||||
|
||||
| Type | Behavior | Use Case |
|
||||
|------|----------|----------|
|
||||
| `full` (default) | Gathers all shards on rank 0, saves single file | Easy deployment, inference |
|
||||
| `sharded` | Each rank saves its shard separately | Faster checkpointing, resume training |
|
||||
|
||||
### Auto-Wrap Policy
|
||||
|
||||
FSDP wraps model components based on `auto_wrap_policy`:
|
||||
|
||||
```python
|
||||
auto_wrap_policy={Block} # Wrap each transformer block
|
||||
```
|
||||
|
||||
This means each `Block` (transformer layer) is independently sharded across GPUs. For a 32-layer model on 4 GPUs, each GPU holds ~8 layer shards.
|
||||
|
||||
## Thunder FSDP (Advanced)
|
||||
|
||||
LitGPT includes an experimental **Thunder** extension with enhanced FSDP:
|
||||
|
||||
```bash
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--num_nodes 1 \
|
||||
--compiler thunder \
|
||||
--strategy fsdp
|
||||
```
|
||||
|
||||
### Thunder FSDP Configuration
|
||||
|
||||
```python
|
||||
from extensions.thunder.pretrain import ThunderFSDPStrategy
|
||||
|
||||
strategy = ThunderFSDPStrategy(
|
||||
sharding_strategy="ZERO3",
|
||||
bucketing_strategy="BLOCK",
|
||||
state_dict_type="full",
|
||||
jit=False,
|
||||
)
|
||||
```
|
||||
|
||||
**Additional Parameters**:
|
||||
- `sharding_strategy`: `"ZERO3"` (full shard), `"ZERO2"` (grad/optim only)
|
||||
- `bucketing_strategy`: `"BLOCK"` (combine ops per block), `"LAYER"` (per layer), `"NONE"` (no bucketing)
|
||||
- `jit`: Whether to apply `thunder.jit(model)` for optimization
|
||||
- `executors`: Tuple of Thunder executors to enable
|
||||
|
||||
**Bucketing Strategy**:
|
||||
- `"BLOCK"` (default): Combines collective operations for layer blocks → fewer communication calls
|
||||
- `"LAYER"`: Combines per layer class
|
||||
- `"NONE"`: No bucketing → more fine-grained but more overhead
|
||||
|
||||
## Pretraining with FSDP
|
||||
|
||||
### Single Node
|
||||
|
||||
```bash
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--train.global_batch_size 512 \
|
||||
--train.micro_batch_size 8 \
|
||||
--data Alpaca2k
|
||||
```
|
||||
|
||||
**Memory calculation**:
|
||||
- TinyLlama 1.1B: ~4GB model + ~4GB gradients + ~8GB optimizer = 16GB per GPU without FSDP
|
||||
- With FSDP on 8 GPUs: 16GB / 8 = 2GB per GPU ✅ Fits easily
|
||||
|
||||
### Multi-Node
|
||||
|
||||
```bash
|
||||
# Launch on 4 nodes with 8 GPUs each (32 total)
|
||||
litgpt pretrain llama-2-7b \
|
||||
--devices 8 \
|
||||
--num_nodes 4 \
|
||||
--train.global_batch_size 1024 \
|
||||
--train.micro_batch_size 2 \
|
||||
--data RedPajama
|
||||
```
|
||||
|
||||
**Memory calculation**:
|
||||
- Llama 2 7B: ~28GB model + ~28GB gradients + ~56GB optimizer = 112GB total
|
||||
- With FSDP on 32 GPUs: 112GB / 32 = 3.5GB per GPU ✅
|
||||
|
||||
## Fine-tuning with FSDP
|
||||
|
||||
### LoRA Fine-tuning (Recommended)
|
||||
|
||||
LoRA fine-tuning with FSDP for >7B models:
|
||||
|
||||
```bash
|
||||
# Llama 2 70B LoRA on 8 GPUs
|
||||
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
|
||||
--devices 8 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 16 \
|
||||
--train.micro_batch_size 1 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Why LoRA with FSDP**:
|
||||
- Base model sharded with FSDP (memory efficient)
|
||||
- Only LoRA adapters trained (fast)
|
||||
- Best of both worlds for large models
|
||||
|
||||
### Full Fine-tuning
|
||||
|
||||
Full fine-tuning with FSDP:
|
||||
|
||||
```bash
|
||||
# Llama 2 7B full fine-tune on 4 GPUs
|
||||
litgpt finetune_full meta-llama/Llama-2-7b-hf \
|
||||
--devices 4 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 16 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.learning_rate 3e-5
|
||||
```
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
FSDP works with mixed precision for memory savings and speedup:
|
||||
|
||||
```bash
|
||||
# BF16 mixed precision (recommended for A100/H100)
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--precision bf16-mixed
|
||||
|
||||
# FP16 mixed precision (V100 compatible)
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--precision 16-mixed
|
||||
```
|
||||
|
||||
**Precision options**:
|
||||
- `bf16-mixed`: BF16 for computation, FP32 for master weights (best for Ampere+)
|
||||
- `16-mixed`: FP16 for computation, FP32 for master weights (V100)
|
||||
- `32-true`: Full FP32 (debugging only, slow)
|
||||
|
||||
## Gradient Accumulation
|
||||
|
||||
Simulate larger batch sizes with gradient accumulation:
|
||||
|
||||
```bash
|
||||
# Simulate global_batch_size=512 with micro_batch_size=2
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--train.global_batch_size 512 \
|
||||
--train.micro_batch_size 2
|
||||
# Accumulates over 512/(8*2) = 32 steps per optimizer update
|
||||
```
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
Gradient accumulation steps = global_batch_size / (devices × micro_batch_size)
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Out of Memory? Try These
|
||||
|
||||
1. **Increase devices**:
|
||||
```bash
|
||||
--devices 8 # Instead of 4
|
||||
```
|
||||
|
||||
2. **Reduce micro batch size**:
|
||||
```bash
|
||||
--train.micro_batch_size 1 # Instead of 2
|
||||
```
|
||||
|
||||
3. **Lower precision**:
|
||||
```bash
|
||||
--precision bf16-mixed # Instead of 32-true
|
||||
```
|
||||
|
||||
4. **Use FULL_SHARD**:
|
||||
```python
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD" # Maximum memory savings
|
||||
)
|
||||
```
|
||||
|
||||
5. **Enable activation checkpointing** (implemented in model):
|
||||
```python
|
||||
# Recomputes activations during backward pass
|
||||
# Trades compute for memory
|
||||
```
|
||||
|
||||
6. **Use QLoRA**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--quantize bnb.nf4 \
|
||||
--devices 1 # May not need FSDP with quantization
|
||||
```
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoints
|
||||
|
||||
FSDP automatically handles checkpoint saving:
|
||||
|
||||
```bash
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--out_dir checkpoints/tinyllama-pretrain
|
||||
# Saves to: checkpoints/tinyllama-pretrain/final/lit_model.pth
|
||||
```
|
||||
|
||||
With `state_dict_type="full"` (default), rank 0 assembles full model and saves single file.
|
||||
|
||||
### Resume Training
|
||||
|
||||
```bash
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--resume checkpoints/tinyllama-pretrain/
|
||||
# Automatically loads latest checkpoint
|
||||
```
|
||||
|
||||
### Convert to HuggingFace
|
||||
|
||||
```bash
|
||||
python scripts/convert_lit_checkpoint.py \
|
||||
--checkpoint_path checkpoints/tinyllama-pretrain/final/lit_model.pth \
|
||||
--output_dir models/tinyllama-hf
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Communication Backends
|
||||
|
||||
LitGPT uses NCCL for GPU communication:
|
||||
|
||||
```bash
|
||||
# Default (NCCL auto-configured)
|
||||
litgpt pretrain tiny-llama-1.1b --devices 8
|
||||
|
||||
# Explicit NCCL settings (advanced)
|
||||
NCCL_DEBUG=INFO \
|
||||
NCCL_IB_DISABLE=0 \
|
||||
litgpt pretrain tiny-llama-1.1b --devices 8
|
||||
```
|
||||
|
||||
**NCCL Environment Variables**:
|
||||
- `NCCL_DEBUG=INFO`: Enable debug logging
|
||||
- `NCCL_IB_DISABLE=0`: Use InfiniBand (if available)
|
||||
- `NCCL_SOCKET_IFNAME=eth0`: Specify network interface
|
||||
|
||||
### Multi-Node Setup
|
||||
|
||||
**Option 1: SLURM**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --ntasks-per-node=1
|
||||
|
||||
srun litgpt pretrain llama-2-7b \
|
||||
--devices 8 \
|
||||
--num_nodes 4 \
|
||||
--data RedPajama
|
||||
```
|
||||
|
||||
**Option 2: torchrun**
|
||||
|
||||
```bash
|
||||
# On each node, run:
|
||||
torchrun \
|
||||
--nproc_per_node=8 \
|
||||
--nnodes=4 \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--master_port=29500 \
|
||||
-m litgpt pretrain llama-2-7b
|
||||
```
|
||||
|
||||
### Profiling
|
||||
|
||||
Enable profiling to identify bottlenecks:
|
||||
|
||||
```bash
|
||||
litgpt pretrain tiny-llama-1.1b \
|
||||
--devices 8 \
|
||||
--train.max_steps 100 \
|
||||
--profile
|
||||
# Generates profiling report
|
||||
```
|
||||
|
||||
## Example Configurations
|
||||
|
||||
### Llama 2 7B on 4× A100 (40GB)
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--devices 4 \
|
||||
--precision bf16-mixed \
|
||||
--train.global_batch_size 64 \
|
||||
--train.micro_batch_size 4 \
|
||||
--train.max_seq_length 2048 \
|
||||
--lora_r 8 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json
|
||||
```
|
||||
|
||||
**Memory per GPU**: ~20GB
|
||||
**Throughput**: ~5 samples/sec
|
||||
|
||||
### Llama 2 70B on 8× A100 (80GB)
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-70b-hf \
|
||||
--devices 8 \
|
||||
--precision bf16-mixed \
|
||||
--train.global_batch_size 32 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.max_seq_length 2048 \
|
||||
--lora_r 8 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json
|
||||
```
|
||||
|
||||
**Memory per GPU**: ~70GB
|
||||
**Throughput**: ~1 sample/sec
|
||||
|
||||
### Llama 3 405B on 64× H100 (80GB)
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-3.1-405B \
|
||||
--devices 8 \
|
||||
--num_nodes 8 \
|
||||
--precision bf16-mixed \
|
||||
--train.global_batch_size 128 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.max_seq_length 4096 \
|
||||
--lora_r 16 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json
|
||||
```
|
||||
|
||||
**Memory per GPU**: ~60GB
|
||||
**Requires**: 64 H100 GPUs (8 nodes × 8 GPUs)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "CUDA out of memory"
|
||||
|
||||
1. Reduce `micro_batch_size`
|
||||
2. Increase `devices` (more sharding)
|
||||
3. Lower `max_seq_length`
|
||||
4. Use `bf16-mixed` precision
|
||||
5. Try QLoRA (`--quantize bnb.nf4`)
|
||||
|
||||
### "NCCL error" or Slow Communication
|
||||
|
||||
1. Check network connectivity between nodes
|
||||
2. Enable InfiniBand: `NCCL_IB_DISABLE=0`
|
||||
3. Verify NCCL version: `python -c "import torch; print(torch.cuda.nccl.version())"`
|
||||
4. Test with NCCL tests: `$NCCL_HOME/build/all_reduce_perf -b 8 -e 128M`
|
||||
|
||||
### Training Slower Than Expected
|
||||
|
||||
1. Profile with `--profile`
|
||||
2. Check GPU utilization: `nvidia-smi dmon`
|
||||
3. Verify data loading isn't bottleneck
|
||||
4. Increase `micro_batch_size` if memory allows
|
||||
5. Use Thunder FSDP with bucketing
|
||||
|
||||
## References
|
||||
|
||||
- FSDP configuration: `litgpt/pretrain.py:setup()`
|
||||
- Thunder FSDP: `extensions/thunder/pretrain.py`
|
||||
- Memory optimization guide: `tutorials/oom.md`
|
||||
- Lightning Fabric docs: https://lightning.ai/docs/fabric/
|
||||
@@ -0,0 +1,336 @@
|
||||
# Supported Models
|
||||
|
||||
Complete list of model architectures supported by LitGPT with parameter sizes and variants.
|
||||
|
||||
## Overview
|
||||
|
||||
LitGPT supports **20+ model families** with **100+ model variants** ranging from 135M to 405B parameters.
|
||||
|
||||
**List all models**:
|
||||
```bash
|
||||
litgpt download list
|
||||
```
|
||||
|
||||
**List pretrain-capable models**:
|
||||
```bash
|
||||
litgpt pretrain list
|
||||
```
|
||||
|
||||
## Model Families
|
||||
|
||||
### Llama Family
|
||||
|
||||
**Llama 3, 3.1, 3.2, 3.3**:
|
||||
- **Sizes**: 1B, 3B, 8B, 70B, 405B
|
||||
- **Use Cases**: General-purpose, long-context (128K), multimodal
|
||||
- **Best For**: Production applications, research, instruction following
|
||||
|
||||
**Code Llama**:
|
||||
- **Sizes**: 7B, 13B, 34B, 70B
|
||||
- **Use Cases**: Code generation, completion, infilling
|
||||
- **Best For**: Programming assistants, code analysis
|
||||
|
||||
**Function Calling Llama 2**:
|
||||
- **Sizes**: 7B
|
||||
- **Use Cases**: Tool use, API integration
|
||||
- **Best For**: Agents, function execution
|
||||
|
||||
**Llama 2**:
|
||||
- **Sizes**: 7B, 13B, 70B
|
||||
- **Use Cases**: General-purpose (predecessor to Llama 3)
|
||||
- **Best For**: Established baselines, research comparisons
|
||||
|
||||
**Llama 3.1 Nemotron**:
|
||||
- **Sizes**: 70B
|
||||
- **Use Cases**: NVIDIA-optimized variant
|
||||
- **Best For**: Enterprise deployments
|
||||
|
||||
**TinyLlama**:
|
||||
- **Sizes**: 1.1B
|
||||
- **Use Cases**: Edge devices, resource-constrained environments
|
||||
- **Best For**: Fast inference, mobile deployment
|
||||
|
||||
**OpenLLaMA**:
|
||||
- **Sizes**: 3B, 7B, 13B
|
||||
- **Use Cases**: Open-source Llama reproduction
|
||||
- **Best For**: Research, education
|
||||
|
||||
**Vicuna**:
|
||||
- **Sizes**: 7B, 13B, 33B
|
||||
- **Use Cases**: Chatbot, instruction following
|
||||
- **Best For**: Conversational AI
|
||||
|
||||
**R1 Distill Llama**:
|
||||
- **Sizes**: 8B, 70B
|
||||
- **Use Cases**: Distilled reasoning models
|
||||
- **Best For**: Efficient reasoning tasks
|
||||
|
||||
**MicroLlama**:
|
||||
- **Sizes**: 300M
|
||||
- **Use Cases**: Extremely small Llama variant
|
||||
- **Best For**: Prototyping, testing
|
||||
|
||||
**Platypus**:
|
||||
- **Sizes**: 7B, 13B, 70B
|
||||
- **Use Cases**: STEM-focused fine-tune
|
||||
- **Best For**: Science, math, technical domains
|
||||
|
||||
### Mistral Family
|
||||
|
||||
**Mistral**:
|
||||
- **Sizes**: 7B, 123B
|
||||
- **Use Cases**: Efficient open models, long-context
|
||||
- **Best For**: Cost-effective deployments
|
||||
|
||||
**Mathstral**:
|
||||
- **Sizes**: 7B
|
||||
- **Use Cases**: Math reasoning
|
||||
- **Best For**: Mathematical problem solving
|
||||
|
||||
**Mixtral MoE**:
|
||||
- **Sizes**: 8×7B (47B total, 13B active), 8×22B (141B total, 39B active)
|
||||
- **Use Cases**: Sparse mixture of experts
|
||||
- **Best For**: High capacity with lower compute
|
||||
|
||||
### Falcon Family
|
||||
|
||||
**Falcon**:
|
||||
- **Sizes**: 7B, 40B, 180B
|
||||
- **Use Cases**: Open-source models from TII
|
||||
- **Best For**: Multilingual applications
|
||||
|
||||
**Falcon 3**:
|
||||
- **Sizes**: 1B, 3B, 7B, 10B
|
||||
- **Use Cases**: Newer Falcon generation
|
||||
- **Best For**: Efficient multilingual models
|
||||
|
||||
### Phi Family (Microsoft)
|
||||
|
||||
**Phi 1.5 & 2**:
|
||||
- **Sizes**: 1.3B, 2.7B
|
||||
- **Use Cases**: Small language models with strong performance
|
||||
- **Best For**: Edge deployment, low-resource environments
|
||||
|
||||
**Phi 3 & 3.5**:
|
||||
- **Sizes**: 3.8B
|
||||
- **Use Cases**: Improved small models
|
||||
- **Best For**: Mobile, browser-based applications
|
||||
|
||||
**Phi 4**:
|
||||
- **Sizes**: 14B
|
||||
- **Use Cases**: Medium-size high-performance model
|
||||
- **Best For**: Balance of size and capability
|
||||
|
||||
**Phi 4 Mini Instruct**:
|
||||
- **Sizes**: 3.8B
|
||||
- **Use Cases**: Instruction-tuned variant
|
||||
- **Best For**: Chat, task completion
|
||||
|
||||
### Gemma Family (Google)
|
||||
|
||||
**Gemma**:
|
||||
- **Sizes**: 2B, 7B
|
||||
- **Use Cases**: Google's open models
|
||||
- **Best For**: Research, education
|
||||
|
||||
**Gemma 2**:
|
||||
- **Sizes**: 2B, 9B, 27B
|
||||
- **Use Cases**: Second generation improvements
|
||||
- **Best For**: Enhanced performance
|
||||
|
||||
**Gemma 3**:
|
||||
- **Sizes**: 1B, 4B, 12B, 27B
|
||||
- **Use Cases**: Latest Gemma generation
|
||||
- **Best For**: State-of-the-art open models
|
||||
|
||||
**CodeGemma**:
|
||||
- **Sizes**: 7B
|
||||
- **Use Cases**: Code-specialized Gemma
|
||||
- **Best For**: Code generation, analysis
|
||||
|
||||
### Qwen Family (Alibaba)
|
||||
|
||||
**Qwen2.5**:
|
||||
- **Sizes**: 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B
|
||||
- **Use Cases**: General-purpose multilingual models
|
||||
- **Best For**: Chinese/English applications
|
||||
|
||||
**Qwen2.5 Coder**:
|
||||
- **Sizes**: 0.5B, 1.5B, 3B, 7B, 14B, 32B
|
||||
- **Use Cases**: Code-specialized variants
|
||||
- **Best For**: Programming in multiple languages
|
||||
|
||||
**Qwen2.5 Math**:
|
||||
- **Sizes**: 1.5B, 7B, 72B
|
||||
- **Use Cases**: Mathematical reasoning
|
||||
- **Best For**: Math problems, STEM education
|
||||
|
||||
**QwQ & QwQ-Preview**:
|
||||
- **Sizes**: 32B
|
||||
- **Use Cases**: Question-answering focus
|
||||
- **Best For**: Reasoning tasks
|
||||
|
||||
### Pythia Family (EleutherAI)
|
||||
|
||||
**Pythia**:
|
||||
- **Sizes**: 14M, 31M, 70M, 160M, 410M, 1B, 1.4B, 2.8B, 6.9B, 12B
|
||||
- **Use Cases**: Research, interpretability
|
||||
- **Best For**: Scientific studies, ablations
|
||||
|
||||
### StableLM Family (Stability AI)
|
||||
|
||||
**StableLM**:
|
||||
- **Sizes**: 3B, 7B
|
||||
- **Use Cases**: Open models from Stability AI
|
||||
- **Best For**: Research, commercial use
|
||||
|
||||
**StableLM Zephyr**:
|
||||
- **Sizes**: 3B
|
||||
- **Use Cases**: Instruction-tuned variant
|
||||
- **Best For**: Chat applications
|
||||
|
||||
**StableCode**:
|
||||
- **Sizes**: 3B
|
||||
- **Use Cases**: Code generation
|
||||
- **Best For**: Programming tasks
|
||||
|
||||
**FreeWilly2 (Stable Beluga 2)**:
|
||||
- **Sizes**: 70B
|
||||
- **Use Cases**: Large Stability AI model
|
||||
- **Best For**: High-capability tasks
|
||||
|
||||
### Other Models
|
||||
|
||||
**Danube2**:
|
||||
- **Sizes**: 1.8B
|
||||
- **Use Cases**: Efficient small model
|
||||
- **Best For**: Resource-constrained environments
|
||||
|
||||
**Dolly**:
|
||||
- **Sizes**: 3B, 7B, 12B
|
||||
- **Use Cases**: Databricks' instruction-following model
|
||||
- **Best For**: Enterprise applications
|
||||
|
||||
**LongChat**:
|
||||
- **Sizes**: 7B, 13B
|
||||
- **Use Cases**: Extended context windows
|
||||
- **Best For**: Long-document understanding
|
||||
|
||||
**Nous-Hermes**:
|
||||
- **Sizes**: 7B, 13B, 70B
|
||||
- **Use Cases**: Instruction-following fine-tune
|
||||
- **Best For**: Task completion, reasoning
|
||||
|
||||
**OLMo**:
|
||||
- **Sizes**: 1B, 7B
|
||||
- **Use Cases**: Allen AI's fully open model
|
||||
- **Best For**: Research transparency
|
||||
|
||||
**RedPajama-INCITE**:
|
||||
- **Sizes**: 3B, 7B
|
||||
- **Use Cases**: Open reproduction project
|
||||
- **Best For**: Research, education
|
||||
|
||||
**Salamandra**:
|
||||
- **Sizes**: 2B, 7B
|
||||
- **Use Cases**: Multilingual European model
|
||||
- **Best For**: European language support
|
||||
|
||||
**SmolLM2**:
|
||||
- **Sizes**: 135M, 360M, 1.7B
|
||||
- **Use Cases**: Ultra-small models
|
||||
- **Best For**: Edge devices, testing
|
||||
|
||||
## Download Examples
|
||||
|
||||
**Download specific model**:
|
||||
```bash
|
||||
litgpt download meta-llama/Llama-3.2-1B
|
||||
litgpt download microsoft/phi-2
|
||||
litgpt download google/gemma-2-9b
|
||||
```
|
||||
|
||||
**Download with HuggingFace token** (for gated models):
|
||||
```bash
|
||||
export HF_TOKEN=hf_...
|
||||
litgpt download meta-llama/Llama-3.1-405B
|
||||
```
|
||||
|
||||
## Model Selection Guide
|
||||
|
||||
### By Use Case
|
||||
|
||||
**General Chat/Instruction Following**:
|
||||
- Small: Phi-2 (2.7B), TinyLlama (1.1B)
|
||||
- Medium: Llama-3.2-8B, Mistral-7B
|
||||
- Large: Llama-3.1-70B, Mixtral-8x22B
|
||||
|
||||
**Code Generation**:
|
||||
- Small: Qwen2.5-Coder-3B
|
||||
- Medium: CodeLlama-13B, CodeGemma-7B
|
||||
- Large: CodeLlama-70B, Qwen2.5-Coder-32B
|
||||
|
||||
**Math/Reasoning**:
|
||||
- Small: Qwen2.5-Math-1.5B
|
||||
- Medium: Mathstral-7B, Qwen2.5-Math-7B
|
||||
- Large: QwQ-32B, Qwen2.5-Math-72B
|
||||
|
||||
**Multilingual**:
|
||||
- Small: SmolLM2-1.7B
|
||||
- Medium: Qwen2.5-7B, Falcon-7B
|
||||
- Large: Qwen2.5-72B
|
||||
|
||||
**Research/Education**:
|
||||
- Pythia family (14M-12B for ablations)
|
||||
- OLMo (fully open)
|
||||
- TinyLlama (fast iteration)
|
||||
|
||||
### By Hardware
|
||||
|
||||
**Consumer GPU (8-16GB VRAM)**:
|
||||
- Phi-2 (2.7B)
|
||||
- TinyLlama (1.1B)
|
||||
- Gemma-2B
|
||||
- SmolLM2 family
|
||||
|
||||
**Single A100 (40-80GB)**:
|
||||
- Llama-3.2-8B
|
||||
- Mistral-7B
|
||||
- CodeLlama-13B
|
||||
- Gemma-9B
|
||||
|
||||
**Multi-GPU (200GB+ total)**:
|
||||
- Llama-3.1-70B (TP=4)
|
||||
- Mixtral-8x22B (TP=2)
|
||||
- Falcon-40B
|
||||
|
||||
**Large Cluster**:
|
||||
- Llama-3.1-405B (FSDP)
|
||||
- Falcon-180B
|
||||
|
||||
## Model Capabilities
|
||||
|
||||
### Context Lengths
|
||||
|
||||
| Model | Context Window |
|
||||
|-------|----------------|
|
||||
| Llama 3.1 | 128K |
|
||||
| Llama 3.2/3.3 | 128K |
|
||||
| Mistral-123B | 128K |
|
||||
| Mixtral | 32K |
|
||||
| Gemma 2 | 8K |
|
||||
| Phi-3 | 128K |
|
||||
| Qwen2.5 | 32K |
|
||||
|
||||
### Training Data
|
||||
|
||||
- **Llama 3**: 15T tokens (multilingual)
|
||||
- **Mistral**: Web data, code
|
||||
- **Qwen**: Multilingual (Chinese/English focus)
|
||||
- **Pythia**: The Pile (controlled training)
|
||||
|
||||
## References
|
||||
|
||||
- LitGPT GitHub: https://github.com/Lightning-AI/litgpt
|
||||
- Model configs: `litgpt/config.py`
|
||||
- Download tutorial: `tutorials/download_model_weights.md`
|
||||
@@ -0,0 +1,619 @@
|
||||
# Training Recipes
|
||||
|
||||
Complete hyperparameter configurations for LoRA, QLoRA, and full fine-tuning across different model sizes.
|
||||
|
||||
## Overview
|
||||
|
||||
LitGPT provides optimized training configurations in `config_hub/finetune/` for various model architectures and fine-tuning methods.
|
||||
|
||||
**Key Configuration Files**:
|
||||
- `config_hub/finetune/*/lora.yaml` - LoRA fine-tuning
|
||||
- `config_hub/finetune/*/qlora.yaml` - 4-bit quantized LoRA
|
||||
- `config_hub/finetune/*/full.yaml` - Full fine-tuning
|
||||
|
||||
## LoRA Fine-tuning Recipes
|
||||
|
||||
### TinyLlama 1.1B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 8
|
||||
lr_warmup_steps: 10
|
||||
epochs: 3
|
||||
max_seq_length: 512
|
||||
|
||||
# LoRA specific
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca_sample.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 8 \
|
||||
--train.lr_warmup_steps 10 \
|
||||
--train.epochs 3 \
|
||||
--train.max_seq_length 512 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 16
|
||||
```
|
||||
|
||||
**Memory**: ~4GB VRAM
|
||||
**Time**: ~30 minutes on RTX 3090
|
||||
|
||||
### Llama 2 7B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
|
||||
# LoRA specific
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 2 \
|
||||
--train.lr_warmup_steps 10 \
|
||||
--train.epochs 4 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 16
|
||||
```
|
||||
|
||||
**Memory**: ~16GB VRAM
|
||||
**Gradient Accumulation**: 4 steps (8 / 2)
|
||||
**Time**: ~6 hours on A100
|
||||
|
||||
### Llama 3 8B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 10
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
|
||||
# LoRA specific
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-3.2-8B \
|
||||
--data JSON \
|
||||
--data.json_path data/custom_dataset.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.lr_warmup_steps 10 \
|
||||
--train.epochs 2 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~20GB VRAM
|
||||
**Gradient Accumulation**: 8 steps
|
||||
**Time**: ~8 hours on A100
|
||||
|
||||
### Mistral 7B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora mistralai/Mistral-7B-v0.1 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 2 \
|
||||
--train.epochs 4 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~16GB VRAM
|
||||
|
||||
### Phi-2 LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 4
|
||||
lr_warmup_steps: 10
|
||||
epochs: 1
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora microsoft/phi-2 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca_sample.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 4 \
|
||||
--train.epochs 1 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~8GB VRAM
|
||||
**Time**: ~20 minutes on RTX 3090
|
||||
|
||||
### Falcon 7B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora tiiuae/falcon-7b \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.epochs 4 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~18GB VRAM
|
||||
|
||||
### Gemma 7B LoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 6
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 200
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora google/gemma-7b \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 6 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.lr_warmup_steps 200 \
|
||||
--train.epochs 2 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~18GB VRAM
|
||||
**Note**: Longer warmup (200 steps) for stability
|
||||
|
||||
## QLoRA Fine-tuning Recipes
|
||||
|
||||
QLoRA uses 4-bit quantization to reduce memory by ~75%.
|
||||
|
||||
### TinyLlama 1.1B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 8
|
||||
lr_warmup_steps: 10
|
||||
epochs: 3
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
|
||||
--quantize bnb.nf4 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca_sample.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 8 \
|
||||
--train.epochs 3 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~2GB VRAM (75% reduction)
|
||||
|
||||
### Llama 2 7B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
min_lr: 6.0e-5
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--quantize bnb.nf4 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 2 \
|
||||
--train.epochs 4 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~6GB VRAM (consumer GPU friendly)
|
||||
|
||||
### Llama 3 8B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 10
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-3.2-8B \
|
||||
--quantize bnb.nf4 \
|
||||
--data JSON \
|
||||
--data.json_path data/custom_dataset.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 2 \
|
||||
--train.epochs 2 \
|
||||
--lora_r 8
|
||||
```
|
||||
|
||||
**Memory**: ~8GB VRAM
|
||||
|
||||
### Mistral 7B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Memory**: ~6GB VRAM
|
||||
|
||||
### Phi-2 QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 4
|
||||
lr_warmup_steps: 10
|
||||
epochs: 1
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Memory**: ~3GB VRAM
|
||||
|
||||
### Falcon 7B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 10
|
||||
epochs: 4
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Memory**: ~6GB VRAM
|
||||
|
||||
### Gemma 2B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 6
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 200
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Memory**: ~3GB VRAM
|
||||
|
||||
### Gemma 7B QLoRA
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 6
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 200
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
quantize: "bnb.nf4"
|
||||
```
|
||||
|
||||
**Memory**: ~6GB VRAM
|
||||
|
||||
## Full Fine-tuning Recipes
|
||||
|
||||
Full fine-tuning updates all model parameters (requires more memory).
|
||||
|
||||
### TinyLlama 1.1B Full
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 2
|
||||
lr_warmup_steps: 100
|
||||
epochs: 3
|
||||
max_seq_length: 512
|
||||
learning_rate: 5e-5
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_full TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 2 \
|
||||
--train.lr_warmup_steps 100 \
|
||||
--train.epochs 3 \
|
||||
--train.learning_rate 5e-5
|
||||
```
|
||||
|
||||
**Memory**: ~12GB VRAM
|
||||
**Time**: ~4 hours on A100
|
||||
|
||||
### Phi-2 Full
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
global_batch_size: 8
|
||||
micro_batch_size: 1
|
||||
lr_warmup_steps: 100
|
||||
epochs: 2
|
||||
max_seq_length: 512
|
||||
learning_rate: 3e-5
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
litgpt finetune_full microsoft/phi-2 \
|
||||
--data JSON \
|
||||
--data.json_path data/alpaca.json \
|
||||
--train.global_batch_size 8 \
|
||||
--train.micro_batch_size 1 \
|
||||
--train.epochs 2 \
|
||||
--train.learning_rate 3e-5
|
||||
```
|
||||
|
||||
**Memory**: ~24GB VRAM
|
||||
|
||||
## Common Hyperparameter Patterns
|
||||
|
||||
### Learning Rates
|
||||
|
||||
| Model Size | LoRA LR | Full Fine-tune LR |
|
||||
|------------|---------|-------------------|
|
||||
| <2B | 3e-4 | 5e-5 |
|
||||
| 2-10B | 1e-4 | 3e-5 |
|
||||
| 10-70B | 5e-5 | 1e-5 |
|
||||
|
||||
### LoRA Rank (r)
|
||||
|
||||
- **r=8**: Default, good balance (recommended)
|
||||
- **r=16**: More capacity, 2× trainable params
|
||||
- **r=32**: Maximum capacity, slower training
|
||||
- **r=4**: Minimal, fastest training
|
||||
|
||||
**Rule of thumb**: Start with r=8, increase if underfitting.
|
||||
|
||||
### Batch Sizes
|
||||
|
||||
| GPU VRAM | Micro Batch | Global Batch |
|
||||
|----------|-------------|--------------|
|
||||
| 8GB | 1 | 8 |
|
||||
| 16GB | 2 | 8-16 |
|
||||
| 40GB | 4 | 16-32 |
|
||||
| 80GB | 8 | 32-64 |
|
||||
|
||||
### Warmup Steps
|
||||
|
||||
- **Small models (<2B)**: 10-50 steps
|
||||
- **Medium models (2-10B)**: 100-200 steps
|
||||
- **Large models (>10B)**: 200-500 steps
|
||||
|
||||
### Epochs
|
||||
|
||||
- **Instruction tuning**: 1-3 epochs
|
||||
- **Domain adaptation**: 3-5 epochs
|
||||
- **Small datasets (<10K)**: 5-10 epochs
|
||||
|
||||
## Advanced Configurations
|
||||
|
||||
### Custom Learning Rate Schedule
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--train.learning_rate 3e-4 \
|
||||
--train.lr_warmup_steps 100 \
|
||||
--train.min_lr 3e-6 \
|
||||
--train.lr_decay_iters 10000
|
||||
```
|
||||
|
||||
### Gradient Accumulation
|
||||
|
||||
```bash
|
||||
# Simulate global_batch_size=128 with 16GB GPU
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--train.global_batch_size 128 \
|
||||
--train.micro_batch_size 2
|
||||
# Accumulates over 64 steps (128 / 2)
|
||||
```
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--precision bf16-mixed # BF16 mixed precision
|
||||
# or
|
||||
--precision 16-mixed # FP16 mixed precision
|
||||
```
|
||||
|
||||
### Longer Context
|
||||
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-3.1-8B \
|
||||
--train.max_seq_length 8192 \
|
||||
--train.micro_batch_size 1 # Reduce batch for memory
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Out of Memory? Try These
|
||||
|
||||
1. **Enable quantization**:
|
||||
```bash
|
||||
--quantize bnb.nf4 # 4-bit QLoRA
|
||||
```
|
||||
|
||||
2. **Reduce batch size**:
|
||||
```bash
|
||||
--train.micro_batch_size 1
|
||||
```
|
||||
|
||||
3. **Lower LoRA rank**:
|
||||
```bash
|
||||
--lora_r 4 # Instead of 8
|
||||
```
|
||||
|
||||
4. **Use FSDP** (multi-GPU):
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--devices 4 # Use 4 GPUs with FSDP
|
||||
```
|
||||
|
||||
5. **Gradient checkpointing**:
|
||||
```bash
|
||||
--train.gradient_accumulation_iters 16
|
||||
```
|
||||
|
||||
## Data Format
|
||||
|
||||
LitGPT expects JSON data in instruction format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "What is the capital of France?",
|
||||
"input": "",
|
||||
"output": "The capital of France is Paris."
|
||||
},
|
||||
{
|
||||
"instruction": "Translate to Spanish:",
|
||||
"input": "Hello world",
|
||||
"output": "Hola mundo"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**Load custom data**:
|
||||
```bash
|
||||
litgpt finetune_lora meta-llama/Llama-2-7b-hf \
|
||||
--data JSON \
|
||||
--data.json_path data/my_dataset.json \
|
||||
--data.val_split_fraction 0.1 # 10% validation
|
||||
```
|
||||
|
||||
## Merge and Deploy
|
||||
|
||||
After fine-tuning, merge LoRA weights:
|
||||
|
||||
```bash
|
||||
litgpt merge_lora checkpoints/meta-llama/Llama-2-7b-hf/final_lora.pth
|
||||
```
|
||||
|
||||
Generate with merged model:
|
||||
|
||||
```bash
|
||||
litgpt generate checkpoints/meta-llama/Llama-2-7b-hf-merged/ \
|
||||
--prompt "What is machine learning?"
|
||||
```
|
||||
|
||||
Or serve via API:
|
||||
|
||||
```bash
|
||||
litgpt serve checkpoints/meta-llama/Llama-2-7b-hf-merged/
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- Configuration hub: `config_hub/finetune/`
|
||||
- Fine-tuning tutorial: `tutorials/finetune_*.md`
|
||||
- Memory guide: `tutorials/oom.md`
|
||||
@@ -0,0 +1,260 @@
|
||||
---
|
||||
name: mamba-architecture
|
||||
description: State-space model with O(n) complexity vs Transformers' O(n²). 5× faster inference, million-token sequences, no KV cache. Selective SSM with hardware-aware design. Mamba-1 (d_state=16) and Mamba-2 (d_state=128, multi-head). Models 130M-2.8B on HuggingFace.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Model Architecture, Mamba, State Space Models, SSM, Linear Complexity, Long Context, Efficient Inference, Hardware-Aware, Alternative To Transformers]
|
||||
dependencies: [mamba-ssm, torch, transformers, causal-conv1d]
|
||||
---
|
||||
|
||||
# Mamba - Selective State Space Models
|
||||
|
||||
## Quick start
|
||||
|
||||
Mamba is a state-space model architecture achieving O(n) linear complexity for sequence modeling.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# Install causal-conv1d (optional, for efficiency)
|
||||
pip install causal-conv1d>=1.4.0
|
||||
|
||||
# Install Mamba
|
||||
pip install mamba-ssm
|
||||
# Or both together
|
||||
pip install mamba-ssm[causal-conv1d]
|
||||
```
|
||||
|
||||
**Prerequisites**: Linux, NVIDIA GPU, PyTorch 1.12+, CUDA 11.6+
|
||||
|
||||
**Basic usage** (Mamba block):
|
||||
```python
|
||||
import torch
|
||||
from mamba_ssm import Mamba
|
||||
|
||||
batch, length, dim = 2, 64, 16
|
||||
x = torch.randn(batch, length, dim).to("cuda")
|
||||
|
||||
model = Mamba(
|
||||
d_model=dim, # Model dimension
|
||||
d_state=16, # SSM state dimension
|
||||
d_conv=4, # Conv1d kernel size
|
||||
expand=2 # Expansion factor
|
||||
).to("cuda")
|
||||
|
||||
y = model(x) # O(n) complexity!
|
||||
assert y.shape == x.shape
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Language model with Mamba-2
|
||||
|
||||
**Complete LM with generation**:
|
||||
```python
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
from mamba_ssm.models.config_mamba import MambaConfig
|
||||
import torch
|
||||
|
||||
# Configure Mamba-2 LM
|
||||
config = MambaConfig(
|
||||
d_model=1024, # Hidden dimension
|
||||
n_layer=24, # Number of layers
|
||||
vocab_size=50277, # Vocabulary size
|
||||
ssm_cfg=dict(
|
||||
layer="Mamba2", # Use Mamba-2
|
||||
d_state=128, # Larger state for Mamba-2
|
||||
headdim=64, # Head dimension
|
||||
ngroups=1 # Number of groups
|
||||
)
|
||||
)
|
||||
|
||||
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Generate text
|
||||
input_ids = torch.randint(0, 1000, (1, 20), device="cuda", dtype=torch.long)
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
```
|
||||
|
||||
### Workflow 2: Use pretrained Mamba models
|
||||
|
||||
**Load from HuggingFace**:
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
|
||||
# Load pretrained model
|
||||
model_name = "state-spaces/mamba-2.8b"
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") # Use compatible tokenizer
|
||||
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Generate
|
||||
prompt = "The future of AI is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
|
||||
output_ids = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=200,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.2
|
||||
)
|
||||
generated_text = tokenizer.decode(output_ids[0])
|
||||
print(generated_text)
|
||||
```
|
||||
|
||||
**Available models**:
|
||||
- `state-spaces/mamba-130m`
|
||||
- `state-spaces/mamba-370m`
|
||||
- `state-spaces/mamba-790m`
|
||||
- `state-spaces/mamba-1.4b`
|
||||
- `state-spaces/mamba-2.8b`
|
||||
|
||||
### Workflow 3: Mamba-1 vs Mamba-2
|
||||
|
||||
**Mamba-1** (smaller state):
|
||||
```python
|
||||
from mamba_ssm import Mamba
|
||||
|
||||
model = Mamba(
|
||||
d_model=256,
|
||||
d_state=16, # Smaller state dimension
|
||||
d_conv=4,
|
||||
expand=2
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
**Mamba-2** (multi-head, larger state):
|
||||
```python
|
||||
from mamba_ssm import Mamba2
|
||||
|
||||
model = Mamba2(
|
||||
d_model=256,
|
||||
d_state=128, # Larger state dimension
|
||||
d_conv=4,
|
||||
expand=2,
|
||||
headdim=64, # Head dimension for multi-head
|
||||
ngroups=1 # Parallel groups
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
**Key differences**:
|
||||
- **State size**: Mamba-1 (d_state=16) vs Mamba-2 (d_state=128)
|
||||
- **Architecture**: Mamba-2 has multi-head structure
|
||||
- **Normalization**: Mamba-2 uses RMSNorm
|
||||
- **Distributed**: Mamba-2 supports tensor parallelism
|
||||
|
||||
### Workflow 4: Benchmark vs Transformers
|
||||
|
||||
**Generation speed comparison**:
|
||||
```bash
|
||||
# Benchmark Mamba
|
||||
python benchmarks/benchmark_generation_mamba_simple.py \
|
||||
--model-name "state-spaces/mamba-2.8b" \
|
||||
--prompt "The future of machine learning is" \
|
||||
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
||||
|
||||
# Benchmark Transformer
|
||||
python benchmarks/benchmark_generation_mamba_simple.py \
|
||||
--model-name "EleutherAI/pythia-2.8b" \
|
||||
--prompt "The future of machine learning is" \
|
||||
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
||||
```
|
||||
|
||||
**Expected results**:
|
||||
- **Mamba**: 5× faster inference
|
||||
- **Memory**: No KV cache needed
|
||||
- **Scaling**: Linear with sequence length
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Mamba when**:
|
||||
- Need long sequences (100K+ tokens)
|
||||
- Want faster inference than Transformers
|
||||
- Memory-constrained (no KV cache)
|
||||
- Building streaming applications
|
||||
- Linear scaling important
|
||||
|
||||
**Advantages**:
|
||||
- **O(n) complexity**: Linear vs quadratic
|
||||
- **5× faster inference**: No attention overhead
|
||||
- **No KV cache**: Lower memory usage
|
||||
- **Million-token sequences**: Hardware-efficient
|
||||
- **Streaming**: Constant memory per token
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Transformers**: Need best-in-class performance, have compute
|
||||
- **RWKV**: Want RNN+Transformer hybrid
|
||||
- **RetNet**: Need retention-based architecture
|
||||
- **Hyena**: Want convolution-based approach
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: CUDA out of memory**
|
||||
|
||||
Reduce batch size or use gradient checkpointing:
|
||||
```python
|
||||
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
|
||||
model.gradient_checkpointing_enable() # Enable checkpointing
|
||||
```
|
||||
|
||||
**Issue: Slow installation**
|
||||
|
||||
Install binary wheels (not source):
|
||||
```bash
|
||||
pip install mamba-ssm --no-build-isolation
|
||||
```
|
||||
|
||||
**Issue: Missing causal-conv1d**
|
||||
|
||||
Install separately:
|
||||
```bash
|
||||
pip install causal-conv1d>=1.4.0
|
||||
```
|
||||
|
||||
**Issue: Model not loading from HuggingFace**
|
||||
|
||||
Use `MambaLMHeadModel.from_pretrained` (not `AutoModel`):
|
||||
```python
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Selective SSM**: See [references/selective-ssm.md](references/selective-ssm.md) for mathematical formulation, state-space equations, and how selectivity enables O(n) complexity.
|
||||
|
||||
**Mamba-2 architecture**: See [references/mamba2-details.md](references/mamba2-details.md) for multi-head structure, tensor parallelism, and distributed training setup.
|
||||
|
||||
**Performance optimization**: See [references/performance.md](references/performance.md) for hardware-aware design, CUDA kernels, and memory efficiency techniques.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA with CUDA 11.6+
|
||||
- **VRAM**:
|
||||
- 130M model: 2GB
|
||||
- 370M model: 4GB
|
||||
- 790M model: 8GB
|
||||
- 1.4B model: 14GB
|
||||
- 2.8B model: 28GB (FP16)
|
||||
- **Inference**: 5× faster than Transformers
|
||||
- **Memory**: No KV cache (lower than Transformers)
|
||||
|
||||
**Performance** (vs Transformers):
|
||||
- **Speed**: 5× faster inference
|
||||
- **Memory**: 50% less (no KV cache)
|
||||
- **Scaling**: Linear vs quadratic
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper (Mamba-1): https://arxiv.org/abs/2312.00752 (Dec 2023)
|
||||
- Paper (Mamba-2): https://arxiv.org/abs/2405.21060 (May 2024)
|
||||
- GitHub: https://github.com/state-spaces/mamba ⭐ 13,000+
|
||||
- Models: https://huggingface.co/state-spaces
|
||||
- Docs: Repository README and wiki
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
# Mamba Architecture Details
|
||||
|
||||
## Selective State Space Mechanism
|
||||
|
||||
Mamba's core innovation is the **Selective SSM (S6)** layer that makes state space model parameters input-dependent.
|
||||
|
||||
### How S6 Works
|
||||
|
||||
**Traditional SSMs** (non-selective):
|
||||
```python
|
||||
# Fixed A, B, C matrices for all inputs
|
||||
h(t) = A * h(t-1) + B * x(t) # State update
|
||||
y(t) = C * h(t) # Output
|
||||
```
|
||||
|
||||
**Mamba's Selective SSM**:
|
||||
```python
|
||||
# Input-dependent parameters
|
||||
B(t) = Linear_B(x(t)) # Selection mechanism
|
||||
C(t) = Linear_C(x(t)) # Output projection
|
||||
Δ(t) = Linear_Δ(x(t)) # Discretization step
|
||||
|
||||
# Selective state update
|
||||
h(t) = discretize(A, Δ(t)) * h(t-1) + Δ(t) * B(t) * x(t)
|
||||
y(t) = C(t) * h(t)
|
||||
```
|
||||
|
||||
### Key Advantages
|
||||
|
||||
**1. Content-based reasoning**:
|
||||
- Can selectively remember or forget based on input
|
||||
- Addresses discrete modality weakness of traditional SSMs
|
||||
- Example: Remembers important tokens, forgets padding
|
||||
|
||||
**2. Input-dependent selection**:
|
||||
```python
|
||||
# Mamba decides per token what to remember
|
||||
if is_important(x(t)):
|
||||
Δ(t) = large_value # Keep in state
|
||||
else:
|
||||
Δ(t) = small_value # Forget quickly
|
||||
```
|
||||
|
||||
**3. No attention required**:
|
||||
- Replaces O(n²) attention with O(n) state updates
|
||||
- State dimension is constant (typically 16)
|
||||
|
||||
## Model Configuration
|
||||
|
||||
### Core Parameters
|
||||
|
||||
```python
|
||||
from mamba_ssm import Mamba
|
||||
|
||||
model = Mamba(
|
||||
d_model=256, # Hidden dimension (256, 512, 768, 1024, 2048)
|
||||
d_state=16, # SSM state dimension (fixed at 16 is optimal)
|
||||
d_conv=4, # Local convolution width (4 is standard)
|
||||
expand=2, # Expansion factor (1.5-2.0)
|
||||
dt_rank="auto", # Rank of Δ projection (auto = d_model / 16)
|
||||
dt_min=0.001, # Min Δ init (controls forgetting rate)
|
||||
dt_max=0.1, # Max Δ init
|
||||
dt_init="random", # Δ initialization (random, constant)
|
||||
dt_scale=1.0, # Δ scaling factor
|
||||
conv_bias=True, # Use bias in convolution
|
||||
bias=False # Use bias in linear projections
|
||||
)
|
||||
```
|
||||
|
||||
### Parameter Impact
|
||||
|
||||
**d_state** (SSM state dimension):
|
||||
- Standard: 16 (optimal from ablations)
|
||||
- Smaller (8): Faster but less capacity
|
||||
- Larger (32, 64): Minimal improvement, 2× slower
|
||||
|
||||
**expand** (block expansion):
|
||||
- Standard: 2.0
|
||||
- Range: 1.5-2.0
|
||||
- Controls inner dimension = expand * d_model
|
||||
|
||||
**d_conv** (convolution width):
|
||||
- Standard: 4
|
||||
- Local context window before SSM
|
||||
- Helps with positional information
|
||||
|
||||
**dt_rank** (Δ projection rank):
|
||||
- Auto: d_model / 16 (recommended)
|
||||
- Controls Δ parameter efficiency
|
||||
- Lower rank = more efficient but less expressive
|
||||
|
||||
## Mamba Block Structure
|
||||
|
||||
```python
|
||||
# Mamba block (replaces Transformer block)
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, d_model):
|
||||
self.norm = RMSNorm(d_model)
|
||||
self.mamba = Mamba(d_model, d_state=16, d_conv=4, expand=2)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.mamba(self.norm(x)) # Residual
|
||||
|
||||
# Full model (stack of Mamba blocks)
|
||||
model = nn.Sequential(
|
||||
Embedding(...),
|
||||
*[MambaBlock(d_model) for _ in range(n_layers)],
|
||||
RMSNorm(d_model),
|
||||
LMHead(...)
|
||||
)
|
||||
```
|
||||
|
||||
**Key differences from Transformers**:
|
||||
- No multi-head attention (MHA)
|
||||
- No feedforward network (FFN)
|
||||
- Single Mamba layer per block
|
||||
- 2× more layers than equivalent Transformer
|
||||
|
||||
## Hardware-Aware Implementation
|
||||
|
||||
### Parallel Algorithm
|
||||
|
||||
Mamba uses a **scan-based parallel algorithm** for training:
|
||||
|
||||
```python
|
||||
# Parallel mode (training)
|
||||
# GPU kernel fuses operations
|
||||
y = parallel_scan(A, B, C, x) # O(n log n) parallel
|
||||
|
||||
# Sequential mode (inference)
|
||||
# Constant memory RNN-style
|
||||
h = 0
|
||||
for x_t in sequence:
|
||||
h = A*h + B*x_t
|
||||
y_t = C*h
|
||||
```
|
||||
|
||||
### Memory Efficiency
|
||||
|
||||
**Training**:
|
||||
- Recomputes activations in backward pass
|
||||
- Similar to FlashAttention strategy
|
||||
- Memory: O(batch_size * seq_len * d_model)
|
||||
|
||||
**Inference**:
|
||||
- RNN-style sequential processing
|
||||
- State size: O(d_model * d_state) = constant
|
||||
- No KV cache needed (huge advantage!)
|
||||
|
||||
### CUDA Kernel Optimizations
|
||||
|
||||
```python
|
||||
# Fused kernel operations
|
||||
- Discretization (continuous → discrete A, B)
|
||||
- SSM recurrence (parallel scan)
|
||||
- Convolution (efficient 1D conv)
|
||||
- All in single GPU kernel
|
||||
```
|
||||
|
||||
## Layer Count Scaling
|
||||
|
||||
Mamba models use **2× layers** compared to Transformers:
|
||||
|
||||
| Model | d_model | n_layers | Params |
|
||||
|-------|---------|----------|--------|
|
||||
| Mamba-130M | 768 | 24 | 130M |
|
||||
| Mamba-370M | 1024 | 48 | 370M |
|
||||
| Mamba-790M | 1536 | 48 | 790M |
|
||||
| Mamba-1.4B | 2048 | 48 | 1.4B |
|
||||
| Mamba-2.8B | 2560 | 64 | 2.8B |
|
||||
|
||||
**Why 2× layers?**
|
||||
- Mamba blocks are simpler (no MHA, no FFN)
|
||||
- ~50% fewer parameters per layer
|
||||
- Doubling layers matches compute budget
|
||||
|
||||
## Initialization Strategy
|
||||
|
||||
```python
|
||||
# Δ (discretization step) initialization
|
||||
dt_init_floor = 1e-4
|
||||
dt = torch.exp(
|
||||
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
).clamp(min=dt_init_floor)
|
||||
|
||||
# A (state transition) initialization
|
||||
A = -torch.exp(torch.rand(d_inner, d_state)) # Negative for stability
|
||||
|
||||
# B, C (input/output) initialization
|
||||
B = torch.randn(d_inner, d_state)
|
||||
C = torch.randn(d_inner, d_state)
|
||||
```
|
||||
|
||||
**Critical for stability**:
|
||||
- A must be negative (exponential decay)
|
||||
- Δ in range [dt_min, dt_max]
|
||||
- Random initialization helps diversity
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: https://arxiv.org/abs/2312.00752 (Mamba-1)
|
||||
- Paper: https://arxiv.org/abs/2405.21060 (Mamba-2)
|
||||
- GitHub: https://github.com/state-spaces/mamba
|
||||
- Models: https://huggingface.co/state-spaces
|
||||
- CUDA kernels: https://github.com/state-spaces/mamba/tree/main/csrc
|
||||
@@ -0,0 +1,255 @@
|
||||
# Mamba Performance Benchmarks
|
||||
|
||||
## Inference Speed Comparison
|
||||
|
||||
### Throughput (tokens/sec)
|
||||
|
||||
**Mamba-1.4B vs Transformer-1.3B** on single A100 80GB:
|
||||
|
||||
| Sequence Length | Mamba-1.4B | Transformer-1.3B | Speedup |
|
||||
|----------------|------------|------------------|---------|
|
||||
| 512 | 8,300 | 6,200 | 1.3× |
|
||||
| 1024 | 7,800 | 4,100 | 1.9× |
|
||||
| 2048 | 7,200 | 2,300 | 3.1× |
|
||||
| 4096 | 6,800 | 1,200 | 5.7× |
|
||||
| 8192 | 6,400 | 600 | **10.7×** |
|
||||
| 16384 | 6,100 | OOM | ∞ |
|
||||
|
||||
**Key insight**: Speedup grows with sequence length (Mamba O(n) vs Transformer O(n²))
|
||||
|
||||
### Latency (ms per token)
|
||||
|
||||
**Generation latency** (batch size 1, autoregressive):
|
||||
|
||||
| Model | First Token | Per Token | 100 Tokens Total |
|
||||
|-------|-------------|-----------|------------------|
|
||||
| Mamba-130M | 3 ms | 0.8 ms | 83 ms |
|
||||
| Transformer-130M | 5 ms | 1.2 ms | 125 ms |
|
||||
| Mamba-1.4B | 12 ms | 3.2 ms | 332 ms |
|
||||
| Transformer-1.3B | 18 ms | 8.5 ms | 868 ms |
|
||||
| Mamba-2.8B | 20 ms | 6.1 ms | 631 ms |
|
||||
| Transformer-2.7B | 35 ms | 18.2 ms | 1855 ms |
|
||||
|
||||
**Mamba advantage**: Constant per-token latency regardless of context length
|
||||
|
||||
## Memory Usage
|
||||
|
||||
### Training Memory (BF16, per GPU)
|
||||
|
||||
**Mamba-1.4B** training memory breakdown:
|
||||
|
||||
| Sequence Length | Activations | Gradients | Optimizer | Total | vs Transformer |
|
||||
|----------------|-------------|-----------|-----------|-------|----------------|
|
||||
| 512 | 2.1 GB | 3.2 GB | 11.2 GB | 16.5 GB | 0.9× |
|
||||
| 1024 | 3.8 GB | 3.2 GB | 11.2 GB | 18.2 GB | 0.6× |
|
||||
| 2048 | 7.2 GB | 3.2 GB | 11.2 GB | 21.6 GB | 0.4× |
|
||||
| 4096 | 14.1 GB | 3.2 GB | 11.2 GB | 28.5 GB | 0.25× |
|
||||
| 8192 | 28.0 GB | 3.2 GB | 11.2 GB | 42.4 GB | 0.15× |
|
||||
|
||||
**Note**: Transformer OOMs at 8K sequence length on 40GB A100
|
||||
|
||||
### Inference Memory (FP16, batch size 1)
|
||||
|
||||
| Model | KV Cache (8K ctx) | State (Mamba) | Ratio |
|
||||
|-------|------------------|---------------|-------|
|
||||
| 130M | 2.1 GB | 0 MB | ∞ |
|
||||
| 370M | 5.2 GB | 0 MB | ∞ |
|
||||
| 1.4B | 19.7 GB | 0 MB | ∞ |
|
||||
| 2.8B | 38.4 GB | 0 MB | ∞ |
|
||||
|
||||
**Mamba stores no KV cache** - constant memory per token!
|
||||
|
||||
Actual Mamba state size:
|
||||
- 130M: ~3 MB (d_model × d_state × n_layers = 768 × 16 × 24)
|
||||
- 2.8B: ~13 MB (2560 × 16 × 64)
|
||||
|
||||
## Language Modeling Benchmarks
|
||||
|
||||
### Perplexity on Common Datasets
|
||||
|
||||
**Models trained on The Pile (300B tokens)**:
|
||||
|
||||
| Model | Params | Pile (val) | WikiText-103 | C4 | Lambada |
|
||||
|-------|--------|------------|--------------|-----|---------|
|
||||
| Pythia | 160M | 29.6 | 28.4 | 23.1 | 51.2 |
|
||||
| **Mamba** | **130M** | **28.1** | **26.7** | **21.8** | **48.3** |
|
||||
| Pythia | 410M | 18.3 | 17.6 | 16.2 | 32.1 |
|
||||
| **Mamba** | **370M** | **16.7** | **16.2** | **15.1** | **28.4** |
|
||||
| Pythia | 1.4B | 10.8 | 10.2 | 11.3 | 15.2 |
|
||||
| **Mamba** | **1.4B** | **9.1** | **9.6** | **10.1** | **12.8** |
|
||||
| Pythia | 2.8B | 8.3 | 7.9 | 9.2 | 10.6 |
|
||||
| **Mamba** | **2.8B** | **7.4** | **7.2** | **8.3** | **9.1** |
|
||||
|
||||
**Mamba consistently outperforms** Transformers of similar size by 10-20%
|
||||
|
||||
### Zero-Shot Task Performance
|
||||
|
||||
**Mamba-2.8B vs Transformer-2.7B** on common benchmarks:
|
||||
|
||||
| Task | Mamba-2.8B | Transformer-2.7B | Delta |
|
||||
|------|------------|------------------|-------|
|
||||
| HellaSwag | 61.3 | 58.7 | +2.6 |
|
||||
| PIQA | 78.1 | 76.4 | +1.7 |
|
||||
| ARC-Easy | 68.2 | 65.9 | +2.3 |
|
||||
| ARC-Challenge | 42.7 | 40.1 | +2.6 |
|
||||
| WinoGrande | 64.8 | 62.3 | +2.5 |
|
||||
| OpenBookQA | 43.2 | 41.8 | +1.4 |
|
||||
| BoolQ | 71.4 | 68.2 | +3.2 |
|
||||
| MMLU (5-shot) | 35.2 | 33.8 | +1.4 |
|
||||
|
||||
**Average improvement**: +2.2 points across benchmarks
|
||||
|
||||
## Audio Modeling Benchmarks
|
||||
|
||||
### SC09 (Speech Commands)
|
||||
|
||||
**Task**: Audio classification (10 classes)
|
||||
|
||||
| Model | Params | Accuracy | Inference (ms) |
|
||||
|-------|--------|----------|----------------|
|
||||
| Transformer | 8.2M | 96.2% | 18 ms |
|
||||
| S4 | 6.1M | 97.1% | 8 ms |
|
||||
| **Mamba** | **6.3M** | **98.4%** | **6 ms** |
|
||||
|
||||
### LJSpeech (Speech Generation)
|
||||
|
||||
**Task**: Text-to-speech quality (MOS score)
|
||||
|
||||
| Model | Params | MOS ↑ | RTF ↓ |
|
||||
|-------|--------|-------|-------|
|
||||
| Transformer | 12M | 3.82 | 0.45 |
|
||||
| Conformer | 11M | 3.91 | 0.38 |
|
||||
| **Mamba** | **10M** | **4.03** | **0.21** |
|
||||
|
||||
**RTF** (Real-Time Factor): Lower is better (0.21 = 5× faster than real-time)
|
||||
|
||||
## Genomics Benchmarks
|
||||
|
||||
### Human Reference Genome (HG38)
|
||||
|
||||
**Task**: Next nucleotide prediction
|
||||
|
||||
| Model | Context Length | Perplexity | Throughput |
|
||||
|-------|----------------|------------|------------|
|
||||
| Transformer | 1024 | 3.21 | 1,200 bp/s |
|
||||
| Hyena | 32768 | 2.87 | 8,500 bp/s |
|
||||
| **Mamba** | **1M** | **2.14** | **45,000 bp/s** |
|
||||
|
||||
**Mamba handles million-length sequences** efficiently
|
||||
|
||||
## Scaling Laws
|
||||
|
||||
### Compute-Optimal Training
|
||||
|
||||
**FLOPs vs perplexity** (The Pile validation):
|
||||
|
||||
| Model Size | Training FLOPs | Mamba Perplexity | Transformer Perplexity |
|
||||
|------------|----------------|------------------|------------------------|
|
||||
| 130M | 6e19 | 28.1 | 29.6 |
|
||||
| 370M | 3e20 | 16.7 | 18.3 |
|
||||
| 790M | 8e20 | 12.3 | 13.9 |
|
||||
| 1.4B | 2e21 | 9.1 | 10.8 |
|
||||
| 2.8B | 6e21 | 7.4 | 8.3 |
|
||||
|
||||
**Scaling coefficient**: Mamba achieves same perplexity as Transformer with **0.8×** compute
|
||||
|
||||
### Parameter Efficiency
|
||||
|
||||
**Perplexity 10.0 target** on The Pile:
|
||||
|
||||
| Model Type | Parameters Needed | Memory (inference) |
|
||||
|------------|-------------------|-------------------|
|
||||
| Transformer | 1.6B | 3.2 GB |
|
||||
| **Mamba** | **1.1B** | **2.2 GB** |
|
||||
|
||||
**Mamba needs ~30% fewer parameters** for same performance
|
||||
|
||||
## Long-Range Arena (LRA)
|
||||
|
||||
**Task**: Long-context understanding benchmarks
|
||||
|
||||
| Task | Length | Transformer | S4 | Mamba |
|
||||
|------|--------|-------------|-----|-------|
|
||||
| ListOps | 2K | 36.4% | 59.6% | **61.2%** |
|
||||
| Text | 4K | 64.3% | 86.8% | **88.1%** |
|
||||
| Retrieval | 4K | 57.5% | 90.9% | **92.3%** |
|
||||
| Image | 1K | 42.4% | 88.7% | **89.4%** |
|
||||
| PathFinder | 1K | 71.4% | 86.1% | **87.8%** |
|
||||
| Path-X | 16K | OOM | 88.3% | **91.2%** |
|
||||
|
||||
**Average**: Mamba 85.0%, S4 83.4%, Transformer 54.4%
|
||||
|
||||
## Training Throughput
|
||||
|
||||
### Tokens/sec During Training
|
||||
|
||||
**8× A100 80GB** cluster, BF16, different sequence lengths:
|
||||
|
||||
| Model | Seq Len 512 | Seq Len 2K | Seq Len 8K | Seq Len 32K |
|
||||
|-------|-------------|------------|------------|-------------|
|
||||
| Transformer-1.3B | 180K | 52K | OOM | OOM |
|
||||
| **Mamba-1.4B** | **195K** | **158K** | **121K** | **89K** |
|
||||
| Transformer-2.7B | 92K | 26K | OOM | OOM |
|
||||
| **Mamba-2.8B** | **98K** | **81K** | **62K** | **45K** |
|
||||
|
||||
**Mamba scales to longer sequences** without OOM
|
||||
|
||||
## Hardware Utilization
|
||||
|
||||
### GPU Memory Bandwidth
|
||||
|
||||
**Mamba-1.4B** inference on different GPUs:
|
||||
|
||||
| GPU | Memory BW | Tokens/sec | Efficiency |
|
||||
|-----|-----------|------------|------------|
|
||||
| A100 80GB | 2.0 TB/s | 6,800 | 85% |
|
||||
| A100 40GB | 1.6 TB/s | 5,400 | 84% |
|
||||
| V100 32GB | 900 GB/s | 3,100 | 86% |
|
||||
| RTX 4090 | 1.0 TB/s | 3,600 | 90% |
|
||||
|
||||
**High efficiency**: Mamba is memory-bandwidth bound (good!)
|
||||
|
||||
### Multi-GPU Scaling
|
||||
|
||||
**Mamba-2.8B** training throughput:
|
||||
|
||||
| GPUs | Tokens/sec | Scaling Efficiency |
|
||||
|------|------------|-------------------|
|
||||
| 1× A100 | 12,300 | 100% |
|
||||
| 2× A100 | 23,800 | 97% |
|
||||
| 4× A100 | 46,100 | 94% |
|
||||
| 8× A100 | 89,400 | 91% |
|
||||
| 16× A100 | 172,000 | 88% |
|
||||
|
||||
**Near-linear scaling** up to 16 GPUs
|
||||
|
||||
## Cost Analysis
|
||||
|
||||
### Training Cost (USD)
|
||||
|
||||
**Training to The Pile perplexity 10.0** on cloud GPUs:
|
||||
|
||||
| Model | Cloud GPUs | Hours | Cost (A100) | Cost (H100) |
|
||||
|-------|------------|-------|-------------|-------------|
|
||||
| Transformer-1.6B | 8× A100 | 280 | $8,400 | $4,200 |
|
||||
| **Mamba-1.1B** | **8× A100** | **180** | **$5,400** | **$2,700** |
|
||||
|
||||
**Savings**: 36% cost reduction vs Transformer
|
||||
|
||||
### Inference Cost (USD/million tokens)
|
||||
|
||||
**API-style inference** (batch size 1, 2K context):
|
||||
|
||||
| Model | Latency | Cost/M tokens | Quality (perplexity) |
|
||||
|-------|---------|---------------|---------------------|
|
||||
| Transformer-1.3B | 8.5 ms/tok | $0.42 | 10.8 |
|
||||
| **Mamba-1.4B** | **3.2 ms/tok** | **$0.18** | **9.1** |
|
||||
|
||||
**Mamba provides**: 2.6× faster, 57% cheaper, better quality
|
||||
|
||||
## Resources
|
||||
|
||||
- Benchmarks code: https://github.com/state-spaces/mamba/tree/main/benchmarks
|
||||
- Paper (Mamba-1): https://arxiv.org/abs/2312.00752 (Section 4: Experiments)
|
||||
- Paper (Mamba-2): https://arxiv.org/abs/2405.21060 (Section 5: Experiments)
|
||||
- Pretrained models: https://huggingface.co/state-spaces
|
||||
@@ -0,0 +1,388 @@
|
||||
# Mamba Training Guide
|
||||
|
||||
## Training from Scratch
|
||||
|
||||
### Setup Environment
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install torch>=1.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install packaging ninja
|
||||
pip install causal-conv1d>=1.1.0
|
||||
pip install mamba-ssm
|
||||
|
||||
# Verify CUDA
|
||||
python -c "import torch; print(torch.cuda.is_available())"
|
||||
```
|
||||
|
||||
### Basic Training Loop
|
||||
|
||||
```python
|
||||
import torch
|
||||
from mamba_ssm import Mamba
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Model setup
|
||||
model = Mamba(
|
||||
d_model=512,
|
||||
d_state=16,
|
||||
d_conv=4,
|
||||
expand=2
|
||||
).cuda()
|
||||
|
||||
# Optimizer (same as GPT)
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=6e-4,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.1
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
inputs, targets = batch
|
||||
inputs, targets = inputs.cuda(), targets.cuda()
|
||||
|
||||
# Forward
|
||||
logits = model(inputs)
|
||||
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
|
||||
|
||||
# Backward
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
## Distributed Training
|
||||
|
||||
### Single-Node Multi-GPU (DDP)
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group("nccl")
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# Wrap model
|
||||
model = Mamba(...).cuda()
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
# Train
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
for batch in dataloader:
|
||||
loss = compute_loss(model, batch)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
torchrun --nproc_per_node=8 train.py
|
||||
```
|
||||
|
||||
### Multi-Node Training
|
||||
|
||||
```bash
|
||||
# Node 0 (master)
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=4 --node_rank=0 \
|
||||
--master_addr=$MASTER_ADDR --master_port=29500 \
|
||||
train.py
|
||||
|
||||
# Node 1-3 (workers)
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=4 --node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_ADDR --master_port=29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Mixed Precision Training
|
||||
|
||||
### BF16 (Recommended)
|
||||
|
||||
```python
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
# BF16 (no scaler needed on A100/H100)
|
||||
for batch in dataloader:
|
||||
with autocast(dtype=torch.bfloat16):
|
||||
logits = model(inputs)
|
||||
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### FP16 (with gradient scaling)
|
||||
|
||||
```python
|
||||
scaler = GradScaler()
|
||||
|
||||
for batch in dataloader:
|
||||
with autocast(dtype=torch.float16):
|
||||
logits = model(inputs)
|
||||
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
|
||||
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
```
|
||||
|
||||
## Hyperparameter Recommendations
|
||||
|
||||
### Learning Rate Schedule
|
||||
|
||||
```python
|
||||
# Cosine decay with warmup (GPT-3 style)
|
||||
def get_lr(it, warmup_iters=2000, lr_decay_iters=600000):
|
||||
max_lr = 6e-4
|
||||
min_lr = 6e-5
|
||||
|
||||
# Warmup
|
||||
if it < warmup_iters:
|
||||
return max_lr * it / warmup_iters
|
||||
|
||||
# Decay
|
||||
if it > lr_decay_iters:
|
||||
return min_lr
|
||||
|
||||
# Cosine
|
||||
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
return min_lr + coeff * (max_lr - min_lr)
|
||||
|
||||
# Apply in training loop
|
||||
for it, batch in enumerate(dataloader):
|
||||
lr = get_lr(it)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
```
|
||||
|
||||
### Batch Size Recommendations
|
||||
|
||||
| Model Size | Per-GPU Batch | Gradient Accum | Effective Batch | GPUs |
|
||||
|------------|---------------|----------------|-----------------|------|
|
||||
| 130M | 32 | 4 | 1024 | 8 |
|
||||
| 370M | 16 | 8 | 1024 | 8 |
|
||||
| 790M | 8 | 8 | 512 | 8 |
|
||||
| 1.4B | 4 | 16 | 512 | 8 |
|
||||
| 2.8B | 2 | 16 | 256 | 8 |
|
||||
|
||||
```python
|
||||
# Gradient accumulation
|
||||
accumulation_steps = 8
|
||||
optimizer.zero_grad()
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
loss = compute_loss(model, batch) / accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
if (i + 1) % accumulation_steps == 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Optimizer Configuration
|
||||
|
||||
```python
|
||||
# AdamW (recommended)
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=6e-4, # Peak learning rate
|
||||
betas=(0.9, 0.95), # Standard for LLMs
|
||||
eps=1e-8,
|
||||
weight_decay=0.1 # Important for generalization
|
||||
)
|
||||
|
||||
# Weight decay exemptions (optional)
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
for name, param in model.named_parameters():
|
||||
if 'norm' in name or 'bias' in name:
|
||||
no_decay.add(param)
|
||||
else:
|
||||
decay.add(param)
|
||||
|
||||
optimizer = torch.optim.AdamW([
|
||||
{'params': list(decay), 'weight_decay': 0.1},
|
||||
{'params': list(no_decay), 'weight_decay': 0.0}
|
||||
], lr=6e-4, betas=(0.9, 0.95))
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Gradient Checkpointing
|
||||
|
||||
```python
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, d_model, use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = RMSNorm(d_model)
|
||||
self.mamba = Mamba(d_model)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_checkpoint and self.training:
|
||||
return x + checkpoint(self._forward, x, use_reentrant=False)
|
||||
return x + self._forward(x)
|
||||
|
||||
def _forward(self, x):
|
||||
return self.mamba(self.norm(x))
|
||||
|
||||
# Enable for training
|
||||
model = MambaLM(use_checkpoint=True)
|
||||
```
|
||||
|
||||
**Memory savings**: ~30-40% with minimal speed impact
|
||||
|
||||
### Flash Attention Integration
|
||||
|
||||
Mamba's CUDA kernels already use flash-attention-style optimizations:
|
||||
- Fused operations in single kernel
|
||||
- Recomputation in backward pass
|
||||
- No intermediate activation storage
|
||||
|
||||
## Long Context Training
|
||||
|
||||
### Sequence Length Progression
|
||||
|
||||
```python
|
||||
# Start short, increase gradually
|
||||
training_stages = [
|
||||
{'seq_len': 512, 'iters': 50000},
|
||||
{'seq_len': 1024, 'iters': 100000},
|
||||
{'seq_len': 2048, 'iters': 150000},
|
||||
{'seq_len': 4096, 'iters': 200000},
|
||||
]
|
||||
|
||||
for stage in training_stages:
|
||||
dataloader = create_dataloader(seq_len=stage['seq_len'])
|
||||
train(model, dataloader, max_iters=stage['iters'])
|
||||
```
|
||||
|
||||
### Memory Requirements (Batch Size 1)
|
||||
|
||||
| Sequence Length | 130M Model | 370M Model | 1.4B Model |
|
||||
|----------------|------------|------------|------------|
|
||||
| 2K | 4 GB | 8 GB | 24 GB |
|
||||
| 4K | 5 GB | 10 GB | 32 GB |
|
||||
| 8K | 6 GB | 14 GB | 48 GB |
|
||||
| 16K | 8 GB | 20 GB | 64 GB |
|
||||
| 32K | 12 GB | 32 GB | 96 GB |
|
||||
|
||||
**Mamba advantage**: Memory grows **linearly**, Transformers grow **quadratically**
|
||||
|
||||
## Common Training Issues
|
||||
|
||||
### Issue: OOM during training
|
||||
|
||||
**Solution 1**: Reduce batch size
|
||||
```python
|
||||
per_gpu_batch = 8 # Reduce from 16
|
||||
gradient_accumulation = 8 # Increase from 4
|
||||
```
|
||||
|
||||
**Solution 2**: Enable gradient checkpointing
|
||||
```python
|
||||
model = MambaLM(use_checkpoint=True)
|
||||
```
|
||||
|
||||
**Solution 3**: Use smaller sequence length
|
||||
```python
|
||||
seq_len = 1024 # Reduce from 2048
|
||||
```
|
||||
|
||||
### Issue: Training unstable (loss spikes)
|
||||
|
||||
**Solution 1**: Check gradient norm
|
||||
```python
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
print(f"Grad norm: {grad_norm}") # Should be < 10
|
||||
```
|
||||
|
||||
**Solution 2**: Lower learning rate
|
||||
```python
|
||||
max_lr = 3e-4 # Reduce from 6e-4
|
||||
```
|
||||
|
||||
**Solution 3**: Check Δ initialization
|
||||
```python
|
||||
# Ensure dt_min, dt_max are reasonable
|
||||
model = Mamba(
|
||||
d_model=512,
|
||||
dt_min=0.001, # Not too small
|
||||
dt_max=0.1 # Not too large
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow training speed
|
||||
|
||||
**Solution 1**: Verify CUDA kernels installed
|
||||
```python
|
||||
import mamba_ssm
|
||||
print(mamba_ssm.__version__) # Should have CUDA kernels
|
||||
```
|
||||
|
||||
**Solution 2**: Use BF16 on A100/H100
|
||||
```python
|
||||
with autocast(dtype=torch.bfloat16): # Faster than FP16
|
||||
loss = model(inputs)
|
||||
```
|
||||
|
||||
**Solution 3**: Increase batch size if possible
|
||||
```python
|
||||
per_gpu_batch = 16 # Increase from 8 (better GPU utilization)
|
||||
```
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save/Load Model
|
||||
|
||||
```python
|
||||
# Save
|
||||
checkpoint = {
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'iter': iteration,
|
||||
'config': model_config
|
||||
}
|
||||
torch.save(checkpoint, f'checkpoint_{iteration}.pt')
|
||||
|
||||
# Load
|
||||
checkpoint = torch.load('checkpoint_100000.pt')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
iteration = checkpoint['iter']
|
||||
```
|
||||
|
||||
### Best Practices
|
||||
|
||||
```python
|
||||
# Save every N iterations
|
||||
if iteration % save_interval == 0:
|
||||
save_checkpoint(model, optimizer, iteration)
|
||||
|
||||
# Keep only last K checkpoints
|
||||
checkpoints = sorted(glob.glob('checkpoint_*.pt'))
|
||||
if len(checkpoints) > keep_last:
|
||||
for ckpt in checkpoints[:-keep_last]:
|
||||
os.remove(ckpt)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Training code: https://github.com/state-spaces/mamba/tree/main/benchmarks
|
||||
- Pretrained models: https://huggingface.co/state-spaces
|
||||
- CUDA installation: https://github.com/state-spaces/mamba#installation
|
||||
@@ -0,0 +1,290 @@
|
||||
---
|
||||
name: nanogpt
|
||||
description: Educational GPT implementation in ~300 lines. Reproduces GPT-2 (124M) on OpenWebText. Clean, hackable code for learning transformers. By Andrej Karpathy. Perfect for understanding GPT architecture from scratch. Train on Shakespeare (CPU) or OpenWebText (multi-GPU).
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Model Architecture, NanoGPT, GPT-2, Educational, Andrej Karpathy, Transformer, Minimalist, From Scratch, Training]
|
||||
dependencies: [torch, transformers, datasets, tiktoken, wandb]
|
||||
---
|
||||
|
||||
# nanoGPT - Minimalist GPT Training
|
||||
|
||||
## Quick start
|
||||
|
||||
nanoGPT is a simplified GPT implementation designed for learning and experimentation.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install torch numpy transformers datasets tiktoken wandb tqdm
|
||||
```
|
||||
|
||||
**Train on Shakespeare** (CPU-friendly):
|
||||
```bash
|
||||
# Prepare data
|
||||
python data/shakespeare_char/prepare.py
|
||||
|
||||
# Train (5 minutes on CPU)
|
||||
python train.py config/train_shakespeare_char.py
|
||||
|
||||
# Generate text
|
||||
python sample.py --out_dir=out-shakespeare-char
|
||||
```
|
||||
|
||||
**Output**:
|
||||
```
|
||||
ROMEO:
|
||||
What say'st thou? Shall I speak, and be a man?
|
||||
|
||||
JULIET:
|
||||
I am afeard, and yet I'll speak; for thou art
|
||||
One that hath been a man, and yet I know not
|
||||
What thou art.
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Character-level Shakespeare
|
||||
|
||||
**Complete training pipeline**:
|
||||
```bash
|
||||
# Step 1: Prepare data (creates train.bin, val.bin)
|
||||
python data/shakespeare_char/prepare.py
|
||||
|
||||
# Step 2: Train small model
|
||||
python train.py config/train_shakespeare_char.py
|
||||
|
||||
# Step 3: Generate text
|
||||
python sample.py --out_dir=out-shakespeare-char
|
||||
```
|
||||
|
||||
**Config** (`config/train_shakespeare_char.py`):
|
||||
```python
|
||||
# Model config
|
||||
n_layer = 6 # 6 transformer layers
|
||||
n_head = 6 # 6 attention heads
|
||||
n_embd = 384 # 384-dim embeddings
|
||||
block_size = 256 # 256 char context
|
||||
|
||||
# Training config
|
||||
batch_size = 64
|
||||
learning_rate = 1e-3
|
||||
max_iters = 5000
|
||||
eval_interval = 500
|
||||
|
||||
# Hardware
|
||||
device = 'cpu' # Or 'cuda'
|
||||
compile = False # Set True for PyTorch 2.0
|
||||
```
|
||||
|
||||
**Training time**: ~5 minutes (CPU), ~1 minute (GPU)
|
||||
|
||||
### Workflow 2: Reproduce GPT-2 (124M)
|
||||
|
||||
**Multi-GPU training on OpenWebText**:
|
||||
```bash
|
||||
# Step 1: Prepare OpenWebText (takes ~1 hour)
|
||||
python data/openwebtext/prepare.py
|
||||
|
||||
# Step 2: Train GPT-2 124M with DDP (8 GPUs)
|
||||
torchrun --standalone --nproc_per_node=8 \
|
||||
train.py config/train_gpt2.py
|
||||
|
||||
# Step 3: Sample from trained model
|
||||
python sample.py --out_dir=out
|
||||
```
|
||||
|
||||
**Config** (`config/train_gpt2.py`):
|
||||
```python
|
||||
# GPT-2 (124M) architecture
|
||||
n_layer = 12
|
||||
n_head = 12
|
||||
n_embd = 768
|
||||
block_size = 1024
|
||||
dropout = 0.0
|
||||
|
||||
# Training
|
||||
batch_size = 12
|
||||
gradient_accumulation_steps = 5 * 8 # Total batch ~0.5M tokens
|
||||
learning_rate = 6e-4
|
||||
max_iters = 600000
|
||||
lr_decay_iters = 600000
|
||||
|
||||
# System
|
||||
compile = True # PyTorch 2.0
|
||||
```
|
||||
|
||||
**Training time**: ~4 days (8× A100)
|
||||
|
||||
### Workflow 3: Fine-tune pretrained GPT-2
|
||||
|
||||
**Start from OpenAI checkpoint**:
|
||||
```python
|
||||
# In train.py or config
|
||||
init_from = 'gpt2' # Options: gpt2, gpt2-medium, gpt2-large, gpt2-xl
|
||||
|
||||
# Model loads OpenAI weights automatically
|
||||
python train.py config/finetune_shakespeare.py
|
||||
```
|
||||
|
||||
**Example config** (`config/finetune_shakespeare.py`):
|
||||
```python
|
||||
# Start from GPT-2
|
||||
init_from = 'gpt2'
|
||||
|
||||
# Dataset
|
||||
dataset = 'shakespeare_char'
|
||||
batch_size = 1
|
||||
block_size = 1024
|
||||
|
||||
# Fine-tuning
|
||||
learning_rate = 3e-5 # Lower LR for fine-tuning
|
||||
max_iters = 2000
|
||||
warmup_iters = 100
|
||||
|
||||
# Regularization
|
||||
weight_decay = 1e-1
|
||||
```
|
||||
|
||||
### Workflow 4: Custom dataset
|
||||
|
||||
**Train on your own text**:
|
||||
```python
|
||||
# data/custom/prepare.py
|
||||
import numpy as np
|
||||
|
||||
# Load your data
|
||||
with open('my_data.txt', 'r') as f:
|
||||
text = f.read()
|
||||
|
||||
# Create character mappings
|
||||
chars = sorted(list(set(text)))
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
itos = {i: ch for i, ch in enumerate(chars)}
|
||||
|
||||
# Tokenize
|
||||
data = np.array([stoi[ch] for ch in text], dtype=np.uint16)
|
||||
|
||||
# Split train/val
|
||||
n = len(data)
|
||||
train_data = data[:int(n*0.9)]
|
||||
val_data = data[int(n*0.9):]
|
||||
|
||||
# Save
|
||||
train_data.tofile('data/custom/train.bin')
|
||||
val_data.tofile('data/custom/val.bin')
|
||||
```
|
||||
|
||||
**Train**:
|
||||
```bash
|
||||
python data/custom/prepare.py
|
||||
python train.py --dataset=custom
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use nanoGPT when**:
|
||||
- Learning how GPT works
|
||||
- Experimenting with transformer variants
|
||||
- Teaching/education purposes
|
||||
- Quick prototyping
|
||||
- Limited compute (can run on CPU)
|
||||
|
||||
**Simplicity advantages**:
|
||||
- **~300 lines**: Entire model in `model.py`
|
||||
- **~300 lines**: Training loop in `train.py`
|
||||
- **Hackable**: Easy to modify
|
||||
- **No abstractions**: Pure PyTorch
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **HuggingFace Transformers**: Production use, many models
|
||||
- **Megatron-LM**: Large-scale distributed training
|
||||
- **LitGPT**: More architectures, production-ready
|
||||
- **PyTorch Lightning**: Need high-level framework
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: CUDA out of memory**
|
||||
|
||||
Reduce batch size or context length:
|
||||
```python
|
||||
batch_size = 1 # Reduce from 12
|
||||
block_size = 512 # Reduce from 1024
|
||||
gradient_accumulation_steps = 40 # Increase to maintain effective batch
|
||||
```
|
||||
|
||||
**Issue: Training too slow**
|
||||
|
||||
Enable compilation (PyTorch 2.0+):
|
||||
```python
|
||||
compile = True # 2× speedup
|
||||
```
|
||||
|
||||
Use mixed precision:
|
||||
```python
|
||||
dtype = 'bfloat16' # Or 'float16'
|
||||
```
|
||||
|
||||
**Issue: Poor generation quality**
|
||||
|
||||
Train longer:
|
||||
```python
|
||||
max_iters = 10000 # Increase from 5000
|
||||
```
|
||||
|
||||
Lower temperature:
|
||||
```python
|
||||
# In sample.py
|
||||
temperature = 0.7 # Lower from 1.0
|
||||
top_k = 200 # Add top-k sampling
|
||||
```
|
||||
|
||||
**Issue: Can't load GPT-2 weights**
|
||||
|
||||
Install transformers:
|
||||
```bash
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
Check model name:
|
||||
```python
|
||||
init_from = 'gpt2' # Valid: gpt2, gpt2-medium, gpt2-large, gpt2-xl
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Model architecture**: See [references/architecture.md](references/architecture.md) for GPT block structure, multi-head attention, and MLP layers explained simply.
|
||||
|
||||
**Training loop**: See [references/training.md](references/training.md) for learning rate schedule, gradient accumulation, and distributed data parallel setup.
|
||||
|
||||
**Data preparation**: See [references/data.md](references/data.md) for tokenization strategies (character-level vs BPE) and binary format details.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **Shakespeare (char-level)**:
|
||||
- CPU: 5 minutes
|
||||
- GPU (T4): 1 minute
|
||||
- VRAM: <1GB
|
||||
|
||||
- **GPT-2 (124M)**:
|
||||
- 1× A100: ~1 week
|
||||
- 8× A100: ~4 days
|
||||
- VRAM: ~16GB per GPU
|
||||
|
||||
- **GPT-2 Medium (350M)**:
|
||||
- 8× A100: ~2 weeks
|
||||
- VRAM: ~40GB per GPU
|
||||
|
||||
**Performance**:
|
||||
- With `compile=True`: 2× speedup
|
||||
- With `dtype=bfloat16`: 50% memory reduction
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/karpathy/nanoGPT ⭐ 48,000+
|
||||
- Video: "Let's build GPT" by Andrej Karpathy
|
||||
- Paper: "Attention is All You Need" (Vaswani et al.)
|
||||
- OpenWebText: https://huggingface.co/datasets/Skylion007/openwebtext
|
||||
- Educational: Best for understanding transformers from scratch
|
||||
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
# NanoGPT Architecture
|
||||
|
||||
## Model Structure (~300 Lines)
|
||||
|
||||
NanoGPT implements a clean GPT-2 architecture in minimal code for educational purposes.
|
||||
|
||||
### Complete Model (model.py)
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
"""Multi-head masked self-attention layer."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert config.n_embd % config.n_head == 0
|
||||
|
||||
# Key, query, value projections for all heads (batched)
|
||||
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
||||
# Output projection
|
||||
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
||||
|
||||
# Regularization
|
||||
self.attn_dropout = nn.Dropout(config.dropout)
|
||||
self.resid_dropout = nn.Dropout(config.dropout)
|
||||
|
||||
self.n_head = config.n_head
|
||||
self.n_embd = config.n_embd
|
||||
self.dropout = config.dropout
|
||||
|
||||
# Flash attention flag
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||
|
||||
if not self.flash:
|
||||
# Causal mask (lower triangular)
|
||||
self.register_buffer("bias", torch.tril(
|
||||
torch.ones(config.block_size, config.block_size)
|
||||
).view(1, 1, config.block_size, config.block_size))
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # batch, seq_len, embedding_dim
|
||||
|
||||
# Calculate Q, K, V for all heads in batch
|
||||
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
|
||||
# Attention
|
||||
if self.flash:
|
||||
# Flash Attention (PyTorch 2.0+)
|
||||
y = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
dropout_p=self.dropout if self.training else 0,
|
||||
is_causal=True
|
||||
)
|
||||
else:
|
||||
# Manual attention implementation
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, hs)
|
||||
|
||||
# Reassemble all head outputs
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
||||
|
||||
# Output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
return y
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Feedforward network (2-layer with GELU activation)."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
||||
self.gelu = nn.GELU()
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
x = self.gelu(x)
|
||||
x = self.c_proj(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer block (attention + MLP with residuals)."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.ln_1 = nn.LayerNorm(config.n_embd)
|
||||
self.attn = CausalSelfAttention(config)
|
||||
self.ln_2 = nn.LayerNorm(config.n_embd)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.ln_1(x)) # Pre-norm + residual
|
||||
x = x + self.mlp(self.ln_2(x)) # Pre-norm + residual
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
"""GPT model configuration."""
|
||||
block_size: int = 1024 # Max sequence length
|
||||
vocab_size: int = 50304 # GPT-2 vocab size (50257 rounded up for efficiency)
|
||||
n_layer: int = 12 # Number of layers
|
||||
n_head: int = 12 # Number of attention heads
|
||||
n_embd: int = 768 # Embedding dimension
|
||||
dropout: float = 0.0 # Dropout rate
|
||||
bias: bool = True # Use bias in Linear and LayerNorm layers
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
"""GPT Language Model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert config.vocab_size is not None
|
||||
assert config.block_size is not None
|
||||
self.config = config
|
||||
|
||||
self.transformer = nn.ModuleDict(dict(
|
||||
wte=nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
|
||||
wpe=nn.Embedding(config.block_size, config.n_embd), # Position embeddings
|
||||
drop=nn.Dropout(config.dropout),
|
||||
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
||||
ln_f=nn.LayerNorm(config.n_embd),
|
||||
))
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
||||
# Weight tying (share embeddings and output projection)
|
||||
self.transformer.wte.weight = self.lm_head.weight
|
||||
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
# Apply special scaled init to residual projections
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith('c_proj.weight'):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, idx, targets=None):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
assert t <= self.config.block_size, f"Cannot forward sequence length {t}, max is {self.config.block_size}"
|
||||
|
||||
# Generate position indices
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # (1, t)
|
||||
|
||||
# Forward the GPT model
|
||||
tok_emb = self.transformer.wte(idx) # Token embeddings (b, t, n_embd)
|
||||
pos_emb = self.transformer.wpe(pos) # Position embeddings (1, t, n_embd)
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
|
||||
x = self.transformer.ln_f(x)
|
||||
|
||||
if targets is not None:
|
||||
# Training mode: compute loss
|
||||
logits = self.lm_head(x)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
else:
|
||||
# Inference mode: only compute logits for last token
|
||||
logits = self.lm_head(x[:, [-1], :]) # (b, 1, vocab_size)
|
||||
loss = None
|
||||
|
||||
return logits, loss
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
||||
"""Generate new tokens autoregressively."""
|
||||
for _ in range(max_new_tokens):
|
||||
# Crop context if needed
|
||||
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
||||
|
||||
# Forward pass
|
||||
logits, _ = self(idx_cond)
|
||||
logits = logits[:, -1, :] / temperature # Scale by temperature
|
||||
|
||||
# Optionally crop logits to top k
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
||||
# Sample from distribution
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
# Append to sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
|
||||
return idx
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### 1. Pre-Norm vs Post-Norm
|
||||
|
||||
**NanoGPT uses Pre-Norm** (LayerNorm before sub-layers):
|
||||
|
||||
```python
|
||||
# Pre-norm (NanoGPT)
|
||||
x = x + attn(ln(x))
|
||||
x = x + mlp(ln(x))
|
||||
|
||||
# Post-norm (original Transformer)
|
||||
x = ln(x + attn(x))
|
||||
x = ln(x + mlp(x))
|
||||
```
|
||||
|
||||
**Why Pre-Norm?**
|
||||
- More stable training (no gradient explosion)
|
||||
- Used in GPT-2, GPT-3
|
||||
- Standard for large language models
|
||||
|
||||
### 2. Weight Tying
|
||||
|
||||
**Shared weights between embeddings and output**:
|
||||
|
||||
```python
|
||||
self.transformer.wte.weight = self.lm_head.weight
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Reduces parameters: `vocab_size × n_embd` saved
|
||||
- Improves training (same semantic space)
|
||||
- Standard in GPT-2
|
||||
|
||||
### 3. Scaled Residual Initialization
|
||||
|
||||
```python
|
||||
# Scale down residual projections by layer depth
|
||||
std = 0.02 / math.sqrt(2 * n_layer)
|
||||
torch.nn.init.normal_(c_proj.weight, mean=0.0, std=std)
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Prevents gradient explosion in deep networks
|
||||
- Each residual path contributes ~equally
|
||||
- From GPT-2 paper
|
||||
|
||||
### 4. Flash Attention
|
||||
|
||||
```python
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
# Use PyTorch 2.0 Flash Attention (2× faster!)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
else:
|
||||
# Fallback to manual attention
|
||||
att = (q @ k.T) / sqrt(d)
|
||||
att = masked_fill(att, causal_mask, -inf)
|
||||
y = softmax(att) @ v
|
||||
```
|
||||
|
||||
**Speedup**: 2× faster with same accuracy
|
||||
|
||||
## Model Sizes
|
||||
|
||||
| Model | n_layer | n_head | n_embd | Params | Config Name |
|
||||
|-------|---------|--------|--------|--------|-------------|
|
||||
| GPT-2 Small | 12 | 12 | 768 | 124M | `gpt2` |
|
||||
| GPT-2 Medium | 24 | 16 | 1024 | 350M | `gpt2-medium` |
|
||||
| GPT-2 Large | 36 | 20 | 1280 | 774M | `gpt2-large` |
|
||||
| GPT-2 XL | 48 | 25 | 1600 | 1558M | `gpt2-xl` |
|
||||
|
||||
**NanoGPT default** (Shakespeare):
|
||||
```python
|
||||
config = GPTConfig(
|
||||
block_size=256, # Short context for char-level
|
||||
vocab_size=65, # Small vocab (a-z, A-Z, punctuation)
|
||||
n_layer=6, # Shallow network
|
||||
n_head=6,
|
||||
n_embd=384, # Small embeddings
|
||||
dropout=0.2 # Regularization
|
||||
)
|
||||
# Total: ~10M parameters
|
||||
```
|
||||
|
||||
## Attention Visualization
|
||||
|
||||
```python
|
||||
# What each token attends to (lower triangular)
|
||||
# Token t can only attend to tokens 0...t
|
||||
|
||||
Attention Pattern (causal mask):
|
||||
t=0 t=1 t=2 t=3
|
||||
t=0 ✓ - - -
|
||||
t=1 ✓ ✓ - -
|
||||
t=2 ✓ ✓ ✓ -
|
||||
t=3 ✓ ✓ ✓ ✓
|
||||
|
||||
# Prevents "cheating" by looking at future tokens
|
||||
```
|
||||
|
||||
## Residual Stream
|
||||
|
||||
**Information flow through residuals**:
|
||||
|
||||
```python
|
||||
# Input
|
||||
x = token_emb + pos_emb
|
||||
|
||||
# Block 1
|
||||
x = x + attn_1(ln(x)) # Attention adds to residual
|
||||
x = x + mlp_1(ln(x)) # MLP adds to residual
|
||||
|
||||
# Block 2
|
||||
x = x + attn_2(ln(x))
|
||||
x = x + mlp_2(ln(x))
|
||||
|
||||
# ... (repeat for all layers)
|
||||
|
||||
# Output
|
||||
logits = lm_head(ln(x))
|
||||
```
|
||||
|
||||
**Key insight**: Each layer refines the representation, residuals preserve gradients
|
||||
|
||||
## Tokenization
|
||||
|
||||
### Character-Level (Shakespeare)
|
||||
|
||||
```python
|
||||
# data/shakespeare_char/prepare.py
|
||||
text = open('input.txt', 'r').read()
|
||||
chars = sorted(list(set(text))) # ['!', ',', '.', 'A', 'B', ..., 'z']
|
||||
vocab_size = len(chars) # 65
|
||||
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
itos = {i: ch for i, ch in enumerate(chars)}
|
||||
|
||||
# Encode
|
||||
encode = lambda s: [stoi[c] for c in s]
|
||||
decode = lambda l: ''.join([itos[i] for i in l])
|
||||
|
||||
data = torch.tensor(encode(text), dtype=torch.long)
|
||||
```
|
||||
|
||||
### BPE (GPT-2)
|
||||
|
||||
```python
|
||||
# data/openwebtext/prepare.py
|
||||
import tiktoken
|
||||
|
||||
enc = tiktoken.get_encoding("gpt2") # GPT-2 BPE tokenizer
|
||||
vocab_size = enc.n_vocab # 50257
|
||||
|
||||
# Encode
|
||||
tokens = enc.encode_ordinary("Hello world") # [15496, 995]
|
||||
|
||||
# Decode
|
||||
text = enc.decode(tokens) # "Hello world"
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/karpathy/nanoGPT ⭐ 48,000+
|
||||
- **Video**: "Let's build GPT" by Andrej Karpathy
|
||||
- **Paper**: "Attention is All You Need" (Vaswani et al.)
|
||||
- **Paper**: "Language Models are Unsupervised Multitask Learners" (GPT-2)
|
||||
- **Code walkthrough**: https://github.com/karpathy/nanoGPT/blob/master/ARCHITECTURE.md
|
||||
@@ -0,0 +1,476 @@
|
||||
# NanoGPT Data Preparation
|
||||
|
||||
## Data Format
|
||||
|
||||
NanoGPT uses **binary token files** for efficient loading:
|
||||
|
||||
```
|
||||
dataset/
|
||||
├── train.bin # Training tokens (uint16 array)
|
||||
├── val.bin # Validation tokens (uint16 array)
|
||||
└── meta.pkl # Metadata (vocab_size, mappings)
|
||||
```
|
||||
|
||||
**Why binary?**
|
||||
- 100× faster than reading text files
|
||||
- Memory-mapped loading (no RAM overhead)
|
||||
- Simple format (just token IDs)
|
||||
|
||||
## Character-Level Tokenization
|
||||
|
||||
### Shakespeare Example
|
||||
|
||||
**Input text**:
|
||||
```
|
||||
First Citizen:
|
||||
Before we proceed any further, hear me speak.
|
||||
|
||||
All:
|
||||
Speak, speak.
|
||||
```
|
||||
|
||||
**Character vocabulary** (65 total):
|
||||
```python
|
||||
chars = ['\n', ' ', '!', ',', '.', ':', ';', '?', 'A', 'B', ..., 'z']
|
||||
stoi = {'\n': 0, ' ': 1, '!': 2, ...} # char → ID
|
||||
itos = {0: '\n', 1: ' ', 2: '!', ...} # ID → char
|
||||
```
|
||||
|
||||
**Tokenization**:
|
||||
```python
|
||||
text = "First Citizen:"
|
||||
tokens = [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 63, 43, 52, 10]
|
||||
# F=18, i=47, r=56, s=57, t=58, ' '=1, C=15, ...
|
||||
```
|
||||
|
||||
**Full preparation script**:
|
||||
|
||||
```python
|
||||
# data/shakespeare_char/prepare.py
|
||||
import os
|
||||
import requests
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
# Download Shakespeare dataset
|
||||
input_file = 'input.txt'
|
||||
if not os.path.exists(input_file):
|
||||
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
|
||||
with open(input_file, 'w') as f:
|
||||
f.write(requests.get(url).text)
|
||||
|
||||
# Load text
|
||||
with open(input_file, 'r') as f:
|
||||
data = f.read()
|
||||
|
||||
print(f"Dataset size: {len(data):,} characters")
|
||||
|
||||
# Build vocabulary
|
||||
chars = sorted(list(set(data)))
|
||||
vocab_size = len(chars)
|
||||
print(f"Vocabulary: {vocab_size} unique characters")
|
||||
print(f"Characters: {''.join(chars[:20])}...")
|
||||
|
||||
# Create mappings
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
itos = {i: ch for i, ch in enumerate(chars)}
|
||||
|
||||
# Encode full dataset
|
||||
def encode(s):
|
||||
return [stoi[c] for c in s]
|
||||
|
||||
def decode(l):
|
||||
return ''.join([itos[i] for i in l])
|
||||
|
||||
# Split train/val (90/10)
|
||||
n = len(data)
|
||||
train_data = data[:int(n * 0.9)]
|
||||
val_data = data[int(n * 0.9):]
|
||||
|
||||
# Tokenize
|
||||
train_ids = encode(train_data)
|
||||
val_ids = encode(val_data)
|
||||
|
||||
print(f"Train: {len(train_ids):,} tokens")
|
||||
print(f"Val: {len(val_ids):,} tokens")
|
||||
|
||||
# Save as binary (uint16)
|
||||
train_ids = np.array(train_ids, dtype=np.uint16)
|
||||
val_ids = np.array(val_ids, dtype=np.uint16)
|
||||
|
||||
train_ids.tofile('train.bin')
|
||||
val_ids.tofile('val.bin')
|
||||
|
||||
# Save metadata
|
||||
meta = {
|
||||
'vocab_size': vocab_size,
|
||||
'itos': itos,
|
||||
'stoi': stoi,
|
||||
}
|
||||
|
||||
with open('meta.pkl', 'wb') as f:
|
||||
pickle.dump(meta, f)
|
||||
|
||||
print("Saved train.bin, val.bin, meta.pkl")
|
||||
```
|
||||
|
||||
**Output**:
|
||||
```
|
||||
Dataset size: 1,115,394 characters
|
||||
Vocabulary: 65 unique characters
|
||||
Characters: !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
|
||||
Train: 1,003,854 tokens
|
||||
Val: 111,540 tokens
|
||||
Saved train.bin, val.bin, meta.pkl
|
||||
```
|
||||
|
||||
### Custom Character Dataset
|
||||
|
||||
```python
|
||||
# For your own text dataset
|
||||
text = open('my_data.txt', 'r').read()
|
||||
|
||||
# Build vocab
|
||||
chars = sorted(list(set(text)))
|
||||
vocab_size = len(chars)
|
||||
|
||||
# Create mappings
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
itos = {i: ch for i, ch in enumerate(chars)}
|
||||
|
||||
# Encode
|
||||
encode = lambda s: [stoi[c] for c in s]
|
||||
decode = lambda l: ''.join([itos[i] for i in l])
|
||||
|
||||
# Split and save
|
||||
data = np.array(encode(text), dtype=np.uint16)
|
||||
n = len(data)
|
||||
train = data[:int(n*0.9)]
|
||||
val = data[int(n*0.9):]
|
||||
|
||||
train.tofile('data/custom/train.bin')
|
||||
val.tofile('data/custom/val.bin')
|
||||
|
||||
# Save meta
|
||||
with open('data/custom/meta.pkl', 'wb') as f:
|
||||
pickle.dump({'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}, f)
|
||||
```
|
||||
|
||||
## BPE (Byte Pair Encoding)
|
||||
|
||||
### OpenWebText with GPT-2 Tokenizer
|
||||
|
||||
**BPE advantages**:
|
||||
- Handles rare words better (subword units)
|
||||
- Standard for GPT-2, GPT-3
|
||||
- Vocabulary: 50,257 tokens
|
||||
|
||||
**Preparation script**:
|
||||
|
||||
```python
|
||||
# data/openwebtext/prepare.py
|
||||
import os
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
# Number of workers for parallel processing
|
||||
num_proc = 8
|
||||
num_proc_load_dataset = num_proc
|
||||
|
||||
# Download OpenWebText dataset
|
||||
dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)
|
||||
|
||||
# Use GPT-2 tokenizer
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
|
||||
def process(example):
|
||||
"""Tokenize a single example."""
|
||||
ids = enc.encode_ordinary(example['text']) # Tokenize
|
||||
ids.append(enc.eot_token) # Add end-of-text token
|
||||
out = {'ids': ids, 'len': len(ids)}
|
||||
return out
|
||||
|
||||
# Tokenize entire dataset (parallel)
|
||||
tokenized = dataset.map(
|
||||
process,
|
||||
remove_columns=['text'],
|
||||
desc="Tokenizing",
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
# Concatenate all into one big array
|
||||
train_ids = np.concatenate([
|
||||
np.array(sample['ids'], dtype=np.uint16)
|
||||
for sample in tqdm(tokenized['train'], desc="Concatenating")
|
||||
])
|
||||
|
||||
print(f"Total tokens: {len(train_ids):,}") # ~9 billion tokens
|
||||
|
||||
# Save train.bin
|
||||
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
|
||||
|
||||
# Create val.bin (sample from train)
|
||||
# Take first 5000 documents for validation
|
||||
val_ids = np.concatenate([
|
||||
np.array(sample['ids'], dtype=np.uint16)
|
||||
for sample in tokenized['train'][:5000]
|
||||
])
|
||||
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
|
||||
|
||||
# Save metadata
|
||||
import pickle
|
||||
meta = {
|
||||
'vocab_size': enc.n_vocab,
|
||||
'eot_token': enc.eot_token,
|
||||
}
|
||||
with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
|
||||
pickle.dump(meta, f)
|
||||
|
||||
print(f"Train tokens: {len(train_ids):,}")
|
||||
print(f"Val tokens: {len(val_ids):,}")
|
||||
print(f"Vocab size: {enc.n_vocab:,}")
|
||||
```
|
||||
|
||||
**Output**:
|
||||
```
|
||||
Total tokens: 9,035,582,198
|
||||
Train tokens: 9,035,582,198
|
||||
Val tokens: 4,123,676
|
||||
Vocab size: 50,257
|
||||
```
|
||||
|
||||
**Time**: 1-2 hours on 8-core CPU
|
||||
|
||||
**Disk usage**:
|
||||
- train.bin: ~18 GB (9B tokens × 2 bytes)
|
||||
- val.bin: ~8 MB
|
||||
- Original text: ~54 GB
|
||||
|
||||
### BPE Tokenization Example
|
||||
|
||||
```python
|
||||
import tiktoken
|
||||
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
|
||||
# Tokenize
|
||||
text = "Hello world! This is a test."
|
||||
tokens = enc.encode_ordinary(text)
|
||||
print(tokens)
|
||||
# [15496, 995, 0, 770, 318, 257, 1332, 13]
|
||||
|
||||
# Decode
|
||||
decoded = enc.decode(tokens)
|
||||
print(decoded)
|
||||
# "Hello world! This is a test."
|
||||
|
||||
# Token → text
|
||||
print([enc.decode([t]) for t in tokens])
|
||||
# ['Hello', ' world', '!', ' This', ' is', ' a', ' test', '.']
|
||||
```
|
||||
|
||||
**Subword splitting**:
|
||||
```python
|
||||
# Rare word "electroencephalography" is split
|
||||
tokens = enc.encode_ordinary("electroencephalography")
|
||||
print([enc.decode([t]) for t in tokens])
|
||||
# ['elect', 'ro', 'ence', 'ph', 'al', 'ography']
|
||||
```
|
||||
|
||||
## Data Loading
|
||||
|
||||
### Memory-Mapped Loading (Efficient)
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Load data (memory-mapped, no RAM overhead)
|
||||
data_dir = 'data/shakespeare_char'
|
||||
train_data = np.memmap(
|
||||
os.path.join(data_dir, 'train.bin'),
|
||||
dtype=np.uint16,
|
||||
mode='r'
|
||||
)
|
||||
|
||||
print(f"Loaded {len(train_data):,} tokens") # No actual read yet!
|
||||
|
||||
# Get batch (read on-demand)
|
||||
def get_batch(split):
|
||||
data = train_data if split == 'train' else val_data
|
||||
|
||||
# Random indices
|
||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
||||
|
||||
# Extract sequences
|
||||
x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
|
||||
y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
|
||||
|
||||
# Move to GPU
|
||||
x, y = x.to('cuda'), y.to('cuda')
|
||||
|
||||
return x, y
|
||||
|
||||
# Usage
|
||||
X, Y = get_batch('train')
|
||||
# X shape: (batch_size, block_size)
|
||||
# Y shape: (batch_size, block_size)
|
||||
```
|
||||
|
||||
**Memory efficiency**:
|
||||
- 9 GB dataset loaded with ~0 MB RAM
|
||||
- Only batch data is loaded into memory
|
||||
|
||||
### Data Loader (PyTorch)
|
||||
|
||||
```python
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class TokenDataset(Dataset):
|
||||
def __init__(self, data_path, block_size):
|
||||
self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
|
||||
self.block_size = block_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data) - self.block_size
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x = torch.from_numpy(self.data[idx:idx+self.block_size].astype(np.int64))
|
||||
y = torch.from_numpy(self.data[idx+1:idx+1+self.block_size].astype(np.int64))
|
||||
return x, y
|
||||
|
||||
# Create data loader
|
||||
train_dataset = TokenDataset('data/shakespeare_char/train.bin', block_size=256)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
# Usage
|
||||
for X, Y in train_loader:
|
||||
X, Y = X.to('cuda'), Y.to('cuda')
|
||||
# Train...
|
||||
```
|
||||
|
||||
## Custom Datasets
|
||||
|
||||
### Wikipedia
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load Wikipedia
|
||||
dataset = load_dataset("wikipedia", "20220301.en", num_proc=8)
|
||||
|
||||
# Tokenize
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
|
||||
def tokenize(example):
|
||||
ids = enc.encode_ordinary(example['text'])
|
||||
return {'ids': ids, 'len': len(ids)}
|
||||
|
||||
tokenized = dataset.map(tokenize, num_proc=8, remove_columns=['text', 'title'])
|
||||
|
||||
# Save
|
||||
train_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['train']])
|
||||
train_ids.tofile('data/wikipedia/train.bin')
|
||||
```
|
||||
|
||||
### Code (GitHub)
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load code dataset (The Stack)
|
||||
dataset = load_dataset("bigcode/the-stack", data_dir="data/python", num_proc=8)
|
||||
|
||||
# Tokenize (same as above)
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
# ... tokenize and save
|
||||
```
|
||||
|
||||
### Custom Text Files
|
||||
|
||||
```python
|
||||
# Load custom text files
|
||||
import glob
|
||||
|
||||
files = glob.glob('my_dataset/*.txt')
|
||||
text = ''
|
||||
|
||||
for file in files:
|
||||
with open(file, 'r') as f:
|
||||
text += f.read() + '\n'
|
||||
|
||||
# Character-level
|
||||
chars = sorted(list(set(text)))
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
data = np.array([stoi[c] for c in text], dtype=np.uint16)
|
||||
|
||||
# Split and save
|
||||
n = len(data)
|
||||
train = data[:int(n*0.9)]
|
||||
val = data[int(n*0.9):]
|
||||
|
||||
train.tofile('data/custom/train.bin')
|
||||
val.tofile('data/custom/val.bin')
|
||||
|
||||
# Meta
|
||||
with open('data/custom/meta.pkl', 'wb') as f:
|
||||
pickle.dump({'vocab_size': len(chars), 'itos': {i: ch for i, ch in enumerate(chars)}, 'stoi': stoi}, f)
|
||||
```
|
||||
|
||||
## Data Augmentation (Advanced)
|
||||
|
||||
### Random Masking (BERT-style)
|
||||
|
||||
```python
|
||||
def random_mask(tokens, mask_prob=0.15):
|
||||
"""Randomly mask tokens for denoising objective."""
|
||||
mask = torch.rand(tokens.shape) < mask_prob
|
||||
tokens[mask] = mask_token_id
|
||||
return tokens
|
||||
|
||||
# Usage in training
|
||||
X, Y = get_batch('train')
|
||||
X_masked = random_mask(X.clone())
|
||||
logits, loss = model(X_masked, Y) # Predict original from masked
|
||||
```
|
||||
|
||||
### Document Shuffling
|
||||
|
||||
```python
|
||||
# Shuffle document order (not token order)
|
||||
# Better generalization than sequential documents
|
||||
|
||||
import random
|
||||
|
||||
# Load documents
|
||||
docs = dataset['train']
|
||||
random.shuffle(docs)
|
||||
|
||||
# Concatenate shuffled
|
||||
train_ids = np.concatenate([np.array(doc['ids'], dtype=np.uint16) for doc in docs])
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
| Dataset | Tokens | Vocab | Prep Time | Disk Size |
|
||||
|---------|--------|-------|-----------|-----------|
|
||||
| Shakespeare (char) | 1M | 65 | 1 sec | 2 MB |
|
||||
| TinyStories | 250M | 50K | 5 min | 500 MB |
|
||||
| OpenWebText | 9B | 50K | 90 min | 18 GB |
|
||||
| The Pile | 300B | 50K | ~2 days | 600 GB |
|
||||
|
||||
## Resources
|
||||
|
||||
- Data preparation scripts: https://github.com/karpathy/nanoGPT/tree/master/data
|
||||
- Tiktoken (BPE tokenizer): https://github.com/openai/tiktoken
|
||||
- HuggingFace datasets: https://huggingface.co/datasets
|
||||
- OpenWebText: https://huggingface.co/datasets/Skylion007/openwebtext
|
||||
- The Stack (code): https://huggingface.co/datasets/bigcode/the-stack
|
||||
@@ -0,0 +1,564 @@
|
||||
# NanoGPT Training Guide
|
||||
|
||||
## Training Loop (~300 Lines)
|
||||
|
||||
NanoGPT's `train.py` is a self-contained training script with minimal dependencies.
|
||||
|
||||
### Complete Training Script Structure
|
||||
|
||||
```python
|
||||
# train.py (simplified)
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import pickle
|
||||
import torch
|
||||
from model import GPTConfig, GPT
|
||||
|
||||
# Training config
|
||||
batch_size = 12 # Micro batch size
|
||||
block_size = 1024 # Context length
|
||||
gradient_accumulation_steps = 5 * 8 # ~60K tokens per batch
|
||||
|
||||
# Model config
|
||||
n_layer = 12
|
||||
n_head = 12
|
||||
n_embd = 768
|
||||
dropout = 0.0
|
||||
|
||||
# Optimizer config
|
||||
learning_rate = 6e-4
|
||||
max_iters = 600000
|
||||
weight_decay = 1e-1
|
||||
beta1 = 0.9
|
||||
beta2 = 0.95
|
||||
grad_clip = 1.0
|
||||
|
||||
# Learning rate schedule
|
||||
warmup_iters = 2000
|
||||
lr_decay_iters = 600000
|
||||
min_lr = 6e-5
|
||||
|
||||
# System
|
||||
device = 'cuda'
|
||||
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
|
||||
compile = True # PyTorch 2.0
|
||||
|
||||
# Data loader
|
||||
def get_batch(split):
|
||||
data = train_data if split == 'train' else val_data
|
||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
||||
x = torch.stack([data[i:i+block_size] for i in ix])
|
||||
y = torch.stack([data[i+1:i+1+block_size] for i in ix])
|
||||
x, y = x.to(device), y.to(device)
|
||||
return x, y
|
||||
|
||||
# Learning rate schedule
|
||||
def get_lr(it):
|
||||
# Warmup
|
||||
if it < warmup_iters:
|
||||
return learning_rate * it / warmup_iters
|
||||
# Decay to min_lr
|
||||
if it > lr_decay_iters:
|
||||
return min_lr
|
||||
# Cosine decay
|
||||
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
return min_lr + coeff * (learning_rate - min_lr)
|
||||
|
||||
# Init model
|
||||
model = GPT(GPTConfig())
|
||||
model.to(device)
|
||||
|
||||
# Compile model (PyTorch 2.0)
|
||||
if compile:
|
||||
print("Compiling model...")
|
||||
model = torch.compile(model)
|
||||
|
||||
# Optimizer
|
||||
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device)
|
||||
|
||||
# Training loop
|
||||
for iter_num in range(max_iters):
|
||||
# Set learning rate
|
||||
lr = get_lr(iter_num)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# Gradient accumulation
|
||||
for micro_step in range(gradient_accumulation_steps):
|
||||
X, Y = get_batch('train')
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
logits, loss = model(X, Y)
|
||||
loss = loss / gradient_accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
# Clip gradients
|
||||
if grad_clip != 0.0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||
|
||||
# Update weights
|
||||
optimizer.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Logging
|
||||
if iter_num % 100 == 0:
|
||||
print(f"iter {iter_num}: loss {loss.item():.4f}, lr {lr:.2e}")
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
### Shakespeare Character-Level
|
||||
|
||||
```bash
|
||||
# Step 1: Download Shakespeare
|
||||
cd data/shakespeare_char
|
||||
python prepare.py
|
||||
|
||||
# Creates:
|
||||
# - train.bin (90% of data, ~1MB)
|
||||
# - val.bin (10% of data, ~110KB)
|
||||
# - meta.pkl (vocab info)
|
||||
```
|
||||
|
||||
**prepare.py**:
|
||||
```python
|
||||
import os
|
||||
import pickle
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
# Download
|
||||
input_file = 'input.txt'
|
||||
if not os.path.exists(input_file):
|
||||
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
|
||||
with open(input_file, 'w') as f:
|
||||
f.write(requests.get(url).text)
|
||||
|
||||
# Read and process
|
||||
with open(input_file, 'r') as f:
|
||||
data = f.read()
|
||||
|
||||
print(f"Length: {len(data):,} characters")
|
||||
|
||||
# Create vocabulary
|
||||
chars = sorted(list(set(data)))
|
||||
vocab_size = len(chars)
|
||||
print(f"Vocab size: {vocab_size}")
|
||||
|
||||
# Create mappings
|
||||
stoi = {ch: i for i, ch in enumerate(chars)}
|
||||
itos = {i: ch for i, ch in enumerate(chars)}
|
||||
|
||||
# Encode dataset
|
||||
data_ids = [stoi[c] for c in data]
|
||||
|
||||
# Train/val split
|
||||
n = len(data_ids)
|
||||
train_ids = data_ids[:int(n*0.9)]
|
||||
val_ids = data_ids[int(n*0.9):]
|
||||
|
||||
# Save as numpy arrays
|
||||
train_ids = np.array(train_ids, dtype=np.uint16)
|
||||
val_ids = np.array(val_ids, dtype=np.uint16)
|
||||
train_ids.tofile('train.bin')
|
||||
val_ids.tofile('val.bin')
|
||||
|
||||
# Save metadata
|
||||
meta = {'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}
|
||||
with open('meta.pkl', 'wb') as f:
|
||||
pickle.dump(meta, f)
|
||||
```
|
||||
|
||||
### OpenWebText (GPT-2 Reproduction)
|
||||
|
||||
```bash
|
||||
# Step 1: Download OpenWebText (~12GB compressed)
|
||||
cd data/openwebtext
|
||||
python prepare.py
|
||||
|
||||
# Warning: Takes 1-2 hours, creates ~54GB of tokenized data
|
||||
```
|
||||
|
||||
**prepare.py**:
|
||||
```python
|
||||
import os
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from datasets import load_dataset
|
||||
|
||||
# Download dataset
|
||||
dataset = load_dataset("openwebtext", num_proc=8)
|
||||
|
||||
# Use GPT-2 tokenizer
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
|
||||
def tokenize(example):
|
||||
ids = enc.encode_ordinary(example['text'])
|
||||
ids.append(enc.eot_token) # Add <|endoftext|>
|
||||
return {'ids': ids, 'len': len(ids)}
|
||||
|
||||
# Tokenize (parallel)
|
||||
tokenized = dataset.map(
|
||||
tokenize,
|
||||
remove_columns=['text'],
|
||||
desc="Tokenizing",
|
||||
num_proc=8
|
||||
)
|
||||
|
||||
# Concatenate all tokens
|
||||
train_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['train']])
|
||||
print(f"Train tokens: {len(train_ids):,}") # ~9B tokens
|
||||
|
||||
# Save
|
||||
train_ids.tofile('train.bin')
|
||||
|
||||
# Validation set (sample)
|
||||
val_ids = np.concatenate([np.array(x['ids'], dtype=np.uint16) for x in tokenized['train'][:5000]])
|
||||
val_ids.tofile('val.bin')
|
||||
|
||||
# Save metadata
|
||||
meta = {'vocab_size': enc.n_vocab, 'eot_token': enc.eot_token}
|
||||
with open('meta.pkl', 'wb') as f:
|
||||
pickle.dump(meta, f)
|
||||
```
|
||||
|
||||
## Learning Rate Schedules
|
||||
|
||||
### Cosine Decay with Warmup (GPT-2 style)
|
||||
|
||||
```python
|
||||
def get_lr(it):
|
||||
# 1) Linear warmup
|
||||
if it < warmup_iters:
|
||||
return learning_rate * it / warmup_iters
|
||||
|
||||
# 2) Constant at min_lr after decay
|
||||
if it > lr_decay_iters:
|
||||
return min_lr
|
||||
|
||||
# 3) Cosine decay in between
|
||||
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
return min_lr + coeff * (learning_rate - min_lr)
|
||||
|
||||
# Example values
|
||||
learning_rate = 6e-4 # Peak LR
|
||||
min_lr = 6e-5 # Final LR (10% of peak)
|
||||
warmup_iters = 2000 # Warmup steps
|
||||
lr_decay_iters = 600000 # Total training steps
|
||||
```
|
||||
|
||||
**Visualization**:
|
||||
```
|
||||
LR
|
||||
^
|
||||
| Peak (6e-4)
|
||||
| /‾‾‾‾‾‾‾‾‾‾\
|
||||
| / \
|
||||
| / \_____ Min (6e-5)
|
||||
| /
|
||||
|/________________> Iteration
|
||||
Warmup Cosine Const
|
||||
(2K) (598K)
|
||||
```
|
||||
|
||||
### Constant LR with Warmup (Simple)
|
||||
|
||||
```python
|
||||
def get_lr(it):
|
||||
if it < warmup_iters:
|
||||
return learning_rate * it / warmup_iters
|
||||
return learning_rate
|
||||
|
||||
# Good for small experiments
|
||||
```
|
||||
|
||||
## Gradient Accumulation
|
||||
|
||||
**Effective batch size** = `batch_size × gradient_accumulation_steps × num_gpus`
|
||||
|
||||
```python
|
||||
# Config
|
||||
batch_size = 12 # Per-GPU micro batch
|
||||
gradient_accumulation_steps = 40 # Accumulate gradients
|
||||
# Effective batch: 12 × 40 = 480 sequences = ~0.5M tokens
|
||||
|
||||
# Training loop
|
||||
optimizer.zero_grad()
|
||||
for micro_step in range(gradient_accumulation_steps):
|
||||
X, Y = get_batch('train')
|
||||
logits, loss = model(X, Y)
|
||||
loss = loss / gradient_accumulation_steps # Scale loss
|
||||
loss.backward() # Accumulate gradients
|
||||
|
||||
# Update once
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Simulates large batch size without OOM
|
||||
- GPT-2 (124M) uses effective batch ~0.5M tokens
|
||||
- More stable training
|
||||
|
||||
## Mixed Precision Training
|
||||
|
||||
### BF16 (Best for A100/H100)
|
||||
|
||||
```python
|
||||
# Enable bfloat16
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Training loop
|
||||
for iter in range(max_iters):
|
||||
X, Y = get_batch('train')
|
||||
|
||||
# Forward in BF16
|
||||
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
logits, loss = model(X, Y)
|
||||
|
||||
# Backward in FP32 (automatic)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- No gradient scaler needed
|
||||
- Same dynamic range as FP32
|
||||
- 2× faster, 50% memory reduction
|
||||
|
||||
### FP16 (V100, older GPUs)
|
||||
|
||||
```python
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
scaler = GradScaler()
|
||||
|
||||
for iter in range(max_iters):
|
||||
X, Y = get_batch('train')
|
||||
|
||||
# Forward in FP16
|
||||
with autocast():
|
||||
logits, loss = model(X, Y)
|
||||
|
||||
# Scale loss, backward
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Unscale, clip gradients
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||
|
||||
# Update weights
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
```
|
||||
|
||||
## Distributed Data Parallel (DDP)
|
||||
|
||||
### Single Node, Multiple GPUs
|
||||
|
||||
```python
|
||||
# train.py (DDP version)
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
# Initialize
|
||||
dist.init_process_group(backend='nccl')
|
||||
ddp_rank = int(os.environ['RANK'])
|
||||
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
||||
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
||||
device = f'cuda:{ddp_local_rank}'
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Model
|
||||
model = GPT(GPTConfig())
|
||||
model.to(device)
|
||||
model = DDP(model, device_ids=[ddp_local_rank])
|
||||
|
||||
# Training loop (same as before, DDP handles gradient sync)
|
||||
for iter in range(max_iters):
|
||||
X, Y = get_batch('train') # Each rank gets different data
|
||||
logits, loss = model(X, Y)
|
||||
loss.backward() # DDP syncs gradients across GPUs
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
# 8 GPUs on single node
|
||||
torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
|
||||
```
|
||||
|
||||
### Multi-Node Training
|
||||
|
||||
```bash
|
||||
# Node 0 (master)
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=4 --node_rank=0 \
|
||||
--master_addr=192.168.1.100 --master_port=29500 \
|
||||
train.py config/train_gpt2.py
|
||||
|
||||
# Node 1-3 (workers)
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=4 --node_rank=$RANK \
|
||||
--master_addr=192.168.1.100 --master_port=29500 \
|
||||
train.py config/train_gpt2.py
|
||||
```
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save every N iterations
|
||||
if iter_num % 5000 == 0:
|
||||
checkpoint = {
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'model_args': model_args,
|
||||
'iter_num': iter_num,
|
||||
'best_val_loss': best_val_loss,
|
||||
'config': config,
|
||||
}
|
||||
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
|
||||
```
|
||||
|
||||
### Resume from Checkpoint
|
||||
|
||||
```python
|
||||
# Load checkpoint
|
||||
init_from = 'resume' # or 'gpt2', 'gpt2-medium', etc.
|
||||
|
||||
if init_from == 'resume':
|
||||
ckpt_path = os.path.join(out_dir, 'ckpt_latest.pt')
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
|
||||
# Restore model
|
||||
model_args = checkpoint['model_args']
|
||||
model = GPT(GPTConfig(**model_args))
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
||||
# Restore optimizer
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
||||
# Restore iteration counter
|
||||
iter_num = checkpoint['iter_num']
|
||||
best_val_loss = checkpoint['best_val_loss']
|
||||
```
|
||||
|
||||
## Fine-Tuning Pretrained Models
|
||||
|
||||
### Load OpenAI GPT-2 Weights
|
||||
|
||||
```python
|
||||
# model.py - from_pretrained method
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_type):
|
||||
"""Load pretrained GPT-2 model weights from HuggingFace."""
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
# Download from HuggingFace
|
||||
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
|
||||
sd_hf = model_hf.state_dict()
|
||||
|
||||
# Filter out keys we don't need
|
||||
sd_hf_keys = [k for k in sd_hf.keys() if not k.endswith('.attn.masked_bias')]
|
||||
sd_hf_keys = [k for k in sd_hf_keys if not k.endswith('.attn.bias')]
|
||||
|
||||
# Create our model
|
||||
config = GPTConfig.from_model_type(model_type)
|
||||
model = GPT(config)
|
||||
sd = model.state_dict()
|
||||
|
||||
# Copy weights (transpose Conv1D → Linear)
|
||||
for k in sd_hf_keys:
|
||||
if any([k.endswith(w) for w in ['.c_attn.weight', '.c_proj.weight', '.c_fc.weight']]):
|
||||
sd[k] = sd_hf[k].t() # Transpose
|
||||
else:
|
||||
sd[k] = sd_hf[k] # Direct copy
|
||||
|
||||
model.load_state_dict(sd)
|
||||
return model
|
||||
|
||||
# Usage
|
||||
model = GPT.from_pretrained('gpt2') # Load GPT-2 (124M)
|
||||
```
|
||||
|
||||
### Fine-Tune on Custom Data
|
||||
|
||||
```python
|
||||
# config/finetune_shakespeare.py
|
||||
init_from = 'gpt2' # Start from GPT-2
|
||||
dataset = 'shakespeare_char'
|
||||
|
||||
# Fine-tuning hyperparameters
|
||||
learning_rate = 3e-5 # Lower LR for fine-tuning
|
||||
max_iters = 2000 # Short fine-tuning
|
||||
warmup_iters = 100
|
||||
|
||||
# Regularization
|
||||
weight_decay = 1e-1
|
||||
dropout = 0.2 # Add dropout
|
||||
|
||||
# Run
|
||||
# python train.py config/finetune_shakespeare.py
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Perplexity
|
||||
|
||||
```python
|
||||
@torch.no_grad()
|
||||
def estimate_loss():
|
||||
model.eval()
|
||||
losses = torch.zeros(eval_iters)
|
||||
|
||||
for k in range(eval_iters):
|
||||
X, Y = get_batch('val')
|
||||
logits, loss = model(X, Y)
|
||||
losses[k] = loss.item()
|
||||
|
||||
model.train()
|
||||
return losses.mean()
|
||||
|
||||
# Usage
|
||||
val_loss = estimate_loss()
|
||||
perplexity = math.exp(val_loss)
|
||||
print(f"Val perplexity: {perplexity:.2f}")
|
||||
```
|
||||
|
||||
### Sample Generation
|
||||
|
||||
```python
|
||||
# sample.py
|
||||
model.eval()
|
||||
|
||||
start = "ROMEO:" # Prompt
|
||||
start_ids = encode(start)
|
||||
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
y = model.generate(x, max_new_tokens=500, temperature=0.8, top_k=200)
|
||||
|
||||
print(decode(y[0].tolist()))
|
||||
```
|
||||
|
||||
## Training Times
|
||||
|
||||
| Setup | Model | Hardware | Batch Size | Time to Perplexity 10 |
|
||||
|-------|-------|----------|------------|----------------------|
|
||||
| Shakespeare | 10M | 1× CPU | 64 | 5 minutes |
|
||||
| Shakespeare | 10M | 1× T4 GPU | 64 | 1 minute |
|
||||
| OpenWebText | 124M | 1× A100 | 480 | 7 days |
|
||||
| OpenWebText | 124M | 8× A100 | 3840 | 4 days |
|
||||
| OpenWebText | 350M | 8× A100 | 1920 | 14 days |
|
||||
|
||||
## Resources
|
||||
|
||||
- Training script: https://github.com/karpathy/nanoGPT/blob/master/train.py
|
||||
- Configs: https://github.com/karpathy/nanoGPT/tree/master/config
|
||||
- Video walkthrough: "Let's build GPT" (training section)
|
||||
- GPT-2 paper: https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
|
||||
@@ -0,0 +1,260 @@
|
||||
---
|
||||
name: rwkv-architecture
|
||||
description: RNN+Transformer hybrid with O(n) inference. Linear time, infinite context, no KV cache. Train like GPT (parallel), infer like RNN (sequential). Linux Foundation AI project. Production at Windows, Office, NeMo. RWKV-7 (March 2025). Models up to 14B parameters.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [RWKV, Model Architecture, RNN, Transformer Hybrid, Linear Complexity, Infinite Context, Efficient Inference, Linux Foundation, Alternative Architecture]
|
||||
dependencies: [rwkv, torch, transformers]
|
||||
---
|
||||
|
||||
# RWKV - Receptance Weighted Key Value
|
||||
|
||||
## Quick start
|
||||
|
||||
RWKV (RwaKuv) combines Transformer parallelization (training) with RNN efficiency (inference).
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# Install PyTorch
|
||||
pip install torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# Install dependencies
|
||||
pip install pytorch-lightning==1.9.5 deepspeed wandb ninja --upgrade
|
||||
|
||||
# Install RWKV
|
||||
pip install rwkv
|
||||
```
|
||||
|
||||
**Basic usage** (GPT mode + RNN mode):
|
||||
```python
|
||||
import os
|
||||
from rwkv.model import RWKV
|
||||
|
||||
os.environ["RWKV_JIT_ON"] = '1'
|
||||
os.environ["RWKV_CUDA_ON"] = '1' # Use CUDA kernel for speed
|
||||
|
||||
# Load model
|
||||
model = RWKV(
|
||||
model='/path/to/RWKV-4-Pile-1B5-20220903-8040',
|
||||
strategy='cuda fp16'
|
||||
)
|
||||
|
||||
# GPT mode (parallel processing)
|
||||
out, state = model.forward([187, 510, 1563, 310, 247], None)
|
||||
print(out.detach().cpu().numpy()) # Logits
|
||||
|
||||
# RNN mode (sequential processing, same result)
|
||||
out, state = model.forward([187, 510], None) # First 2 tokens
|
||||
out, state = model.forward([1563], state) # Next token
|
||||
out, state = model.forward([310, 247], state) # Last tokens
|
||||
print(out.detach().cpu().numpy()) # Same logits as above!
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Text generation (streaming)
|
||||
|
||||
**Efficient token-by-token generation**:
|
||||
```python
|
||||
from rwkv.model import RWKV
|
||||
from rwkv.utils import PIPELINE
|
||||
|
||||
model = RWKV(model='RWKV-4-Pile-14B-20230313-ctx8192-test1050', strategy='cuda fp16')
|
||||
pipeline = PIPELINE(model, "20B_tokenizer.json")
|
||||
|
||||
# Initial prompt
|
||||
prompt = "The future of AI is"
|
||||
state = None
|
||||
|
||||
# Generate token by token
|
||||
for token in prompt:
|
||||
out, state = pipeline.model.forward(pipeline.encode(token), state)
|
||||
|
||||
# Continue generation
|
||||
for _ in range(100):
|
||||
out, state = pipeline.model.forward(None, state)
|
||||
token = pipeline.sample_logits(out)
|
||||
print(pipeline.decode(token), end='', flush=True)
|
||||
```
|
||||
|
||||
**Key advantage**: Constant memory per token (no growing KV cache)
|
||||
|
||||
### Workflow 2: Long context processing (infinite context)
|
||||
|
||||
**Process million-token sequences**:
|
||||
```python
|
||||
model = RWKV(model='RWKV-4-Pile-14B', strategy='cuda fp16')
|
||||
|
||||
# Process very long document
|
||||
state = None
|
||||
long_document = load_document() # e.g., 1M tokens
|
||||
|
||||
# Stream through entire document
|
||||
for chunk in chunks(long_document, chunk_size=1024):
|
||||
out, state = model.forward(chunk, state)
|
||||
|
||||
# State now contains information from entire 1M token document
|
||||
# Memory usage: O(1) (constant, not O(n)!)
|
||||
```
|
||||
|
||||
### Workflow 3: Fine-tuning RWKV
|
||||
|
||||
**Standard fine-tuning workflow**:
|
||||
```python
|
||||
# Training script
|
||||
import pytorch_lightning as pl
|
||||
from rwkv.model import RWKV
|
||||
from rwkv.trainer import RWKVTrainer
|
||||
|
||||
# Configure model
|
||||
config = {
|
||||
'n_layer': 24,
|
||||
'n_embd': 1024,
|
||||
'vocab_size': 50277,
|
||||
'ctx_len': 1024
|
||||
}
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
precision='bf16',
|
||||
strategy='deepspeed_stage_2',
|
||||
max_epochs=1
|
||||
)
|
||||
|
||||
# Train
|
||||
model = RWKV(config)
|
||||
trainer.fit(model, train_dataloader)
|
||||
```
|
||||
|
||||
### Workflow 4: RWKV vs Transformer comparison
|
||||
|
||||
**Memory comparison** (1M token sequence):
|
||||
```python
|
||||
# Transformer (GPT)
|
||||
# Memory: O(n²) for attention
|
||||
# KV cache: 1M × hidden_dim × n_layers × 2 (keys + values)
|
||||
# Example: 1M × 4096 × 24 × 2 = ~400GB (impractical!)
|
||||
|
||||
# RWKV
|
||||
# Memory: O(1) per token
|
||||
# State: hidden_dim × n_layers = 4096 × 24 = ~400KB
|
||||
# 1,000,000× more efficient!
|
||||
```
|
||||
|
||||
**Speed comparison** (inference):
|
||||
```python
|
||||
# Transformer: O(n) per token (quadratic overall)
|
||||
# First token: 1 computation
|
||||
# Second token: 2 computations
|
||||
# ...
|
||||
# 1000th token: 1000 computations
|
||||
|
||||
# RWKV: O(1) per token (linear overall)
|
||||
# Every token: 1 computation
|
||||
# 1000th token: 1 computation (same as first!)
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use RWKV when**:
|
||||
- Need very long context (100K+ tokens)
|
||||
- Want constant memory usage
|
||||
- Building streaming applications
|
||||
- Need RNN efficiency with Transformer performance
|
||||
- Memory-constrained deployment
|
||||
|
||||
**Key advantages**:
|
||||
- **Linear time**: O(n) vs O(n²) for Transformers
|
||||
- **No KV cache**: Constant memory per token
|
||||
- **Infinite context**: No fixed window limit
|
||||
- **Parallelizable training**: Like GPT
|
||||
- **Sequential inference**: Like RNN
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Transformers**: Need absolute best performance, have compute
|
||||
- **Mamba**: Want state-space models
|
||||
- **RetNet**: Need retention mechanism
|
||||
- **Hyena**: Want convolution-based approach
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory during training**
|
||||
|
||||
Use gradient checkpointing and DeepSpeed:
|
||||
```python
|
||||
trainer = pl.Trainer(
|
||||
strategy='deepspeed_stage_3', # Full ZeRO-3
|
||||
precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Issue: Slow inference**
|
||||
|
||||
Enable CUDA kernel:
|
||||
```python
|
||||
os.environ["RWKV_CUDA_ON"] = '1'
|
||||
```
|
||||
|
||||
**Issue: Model not loading**
|
||||
|
||||
Check model path and strategy:
|
||||
```python
|
||||
model = RWKV(
|
||||
model='/absolute/path/to/model.pth',
|
||||
strategy='cuda fp16' # Or 'cpu fp32' for CPU
|
||||
)
|
||||
```
|
||||
|
||||
**Issue: State management in RNN mode**
|
||||
|
||||
Always pass state between forward calls:
|
||||
```python
|
||||
# WRONG: State lost
|
||||
out1, _ = model.forward(tokens1, None)
|
||||
out2, _ = model.forward(tokens2, None) # No context from tokens1!
|
||||
|
||||
# CORRECT: State preserved
|
||||
out1, state = model.forward(tokens1, None)
|
||||
out2, state = model.forward(tokens2, state) # Has context from tokens1
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Time-mixing and channel-mixing**: See [references/architecture-details.md](references/architecture-details.md) for WKV operation, time-decay mechanism, and receptance gates.
|
||||
|
||||
**State management**: See [references/state-management.md](references/state-management.md) for att_x_prev, att_kv, ffn_x_prev states, and numerical stability considerations.
|
||||
|
||||
**RWKV-7 improvements**: See [references/rwkv7.md](references/rwkv7.md) for latest architectural improvements (March 2025) and multimodal capabilities.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA 11.6+) or CPU
|
||||
- **VRAM** (FP16):
|
||||
- 169M model: 1GB
|
||||
- 430M model: 2GB
|
||||
- 1.5B model: 4GB
|
||||
- 3B model: 8GB
|
||||
- 7B model: 16GB
|
||||
- 14B model: 32GB
|
||||
- **Inference**: O(1) memory per token
|
||||
- **Training**: Parallelizable like GPT
|
||||
|
||||
**Performance** (vs Transformers):
|
||||
- **Speed**: Similar training, faster inference
|
||||
- **Memory**: 1000× less for long sequences
|
||||
- **Scaling**: Linear vs quadratic
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper (RWKV): https://arxiv.org/abs/2305.13048 (May 2023)
|
||||
- Paper (RWKV-7): https://arxiv.org/abs/2503.14456 (March 2025)
|
||||
- GitHub: https://github.com/BlinkDL/RWKV-LM ⭐ 12,000+
|
||||
- Docs: https://wiki.rwkv.com/
|
||||
- Models: https://huggingface.co/BlinkDL
|
||||
- Linux Foundation AI: Official project
|
||||
- Production: Microsoft Windows, Office integration, NeMo support
|
||||
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
# RWKV Architecture Details
|
||||
|
||||
## Time-Mixing and Channel-Mixing Blocks
|
||||
|
||||
RWKV alternates between **Time-Mixing** (sequence processing) and **Channel-Mixing** (feature processing) blocks.
|
||||
|
||||
### Time-Mixing Block (WKV Operation)
|
||||
|
||||
The core innovation is the **WKV (Weighted Key-Value)** mechanism:
|
||||
|
||||
```python
|
||||
# Traditional Attention (O(n²))
|
||||
scores = Q @ K.T / sqrt(d) # n×n matrix
|
||||
attention = softmax(scores)
|
||||
output = attention @ V
|
||||
|
||||
# RWKV Time-Mixing (O(n))
|
||||
# Compute WKV in linear time using recurrence
|
||||
for t in range(T):
|
||||
wkv[t] = (exp(w) * k[t] @ v[t] + a[t] * aa[t]) / (exp(w) * k[t] + a[t] * ab[t])
|
||||
aa[t+1] = exp(w) * k[t] @ v[t] + exp(-u) * aa[t]
|
||||
ab[t+1] = exp(w) * k[t] + exp(-u) * ab[t]
|
||||
```
|
||||
|
||||
**Full Time-Mixing implementation**:
|
||||
|
||||
```python
|
||||
class RWKV_TimeMix(nn.Module):
|
||||
def __init__(self, d_model, n_layer):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
|
||||
# Linear projections
|
||||
self.key = nn.Linear(d_model, d_model, bias=False)
|
||||
self.value = nn.Linear(d_model, d_model, bias=False)
|
||||
self.receptance = nn.Linear(d_model, d_model, bias=False)
|
||||
self.output = nn.Linear(d_model, d_model, bias=False)
|
||||
|
||||
# Time-mixing parameters
|
||||
self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
|
||||
self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
|
||||
self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
|
||||
|
||||
# Time-decay and bonus
|
||||
self.time_decay = nn.Parameter(torch.ones(d_model)) # w
|
||||
self.time_first = nn.Parameter(torch.ones(d_model)) # u
|
||||
|
||||
def forward(self, x, state=None):
|
||||
B, T, C = x.shape
|
||||
|
||||
# Time-shift mixing (interpolate with previous token)
|
||||
if state is None:
|
||||
state = torch.zeros(B, C, 3, device=x.device) # [aa, ab, x_prev]
|
||||
|
||||
x_prev = state[:, :, 2].unsqueeze(1) # Previous x
|
||||
xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + x_prev * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
|
||||
|
||||
# Compute k, v, r
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
|
||||
# WKV computation (parallelizable or sequential)
|
||||
wkv = self.wkv(k, v, state[:, :, :2])
|
||||
|
||||
# Apply receptance gate and output projection
|
||||
out = self.output(torch.sigmoid(r) * wkv)
|
||||
|
||||
# Update state
|
||||
new_state = torch.stack([state_aa, state_ab, x[:, -1]], dim=2)
|
||||
|
||||
return out, new_state
|
||||
|
||||
def wkv(self, k, v, state):
|
||||
# Parallel implementation (training)
|
||||
# Sequential implementation (inference) - see below
|
||||
...
|
||||
```
|
||||
|
||||
### WKV Parallel Algorithm (Training)
|
||||
|
||||
```python
|
||||
def wkv_forward(w, u, k, v):
|
||||
"""
|
||||
Parallel WKV computation for training.
|
||||
w: time_decay (d_model,)
|
||||
u: time_first (d_model,)
|
||||
k: keys (batch, seq_len, d_model)
|
||||
v: values (batch, seq_len, d_model)
|
||||
"""
|
||||
B, T, C = k.shape
|
||||
|
||||
# Compute cumulative sums with exponential decay
|
||||
# This is the key to O(n) parallel computation
|
||||
w = -torch.exp(w) # Negative for decay
|
||||
|
||||
# Associative scan operation
|
||||
wkv = torch.zeros(B, T, C, device=k.device)
|
||||
state = torch.zeros(B, C, device=k.device)
|
||||
|
||||
for t in range(T):
|
||||
kv = k[:, t] * v[:, t]
|
||||
wkv[:, t] = (u * kv + state) / (u * k[:, t] + torch.exp(state_count))
|
||||
state = w * state + kv
|
||||
|
||||
return wkv
|
||||
```
|
||||
|
||||
### WKV Sequential Algorithm (Inference)
|
||||
|
||||
```python
|
||||
def wkv_inference(w, u, k, v, state):
|
||||
"""
|
||||
Sequential WKV for O(1) per-token inference.
|
||||
state: (aa, ab) from previous step
|
||||
"""
|
||||
w = -torch.exp(w) # time_decay
|
||||
u = torch.exp(u) # time_first
|
||||
|
||||
# Unpack state
|
||||
aa, ab = state # aa = numerator, ab = denominator
|
||||
|
||||
# Compute WKV for current token
|
||||
kv = k * v
|
||||
wkv = (u * kv + aa) / (u * k + ab)
|
||||
|
||||
# Update state for next token
|
||||
new_aa = w * aa + kv
|
||||
new_ab = w * ab + k
|
||||
|
||||
return wkv, (new_aa, new_ab)
|
||||
```
|
||||
|
||||
### Channel-Mixing Block
|
||||
|
||||
Replaces Transformer FFN with time-shifted variant:
|
||||
|
||||
```python
|
||||
class RWKV_ChannelMix(nn.Module):
|
||||
def __init__(self, d_model, hidden_ratio=4):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.hidden = d_model * hidden_ratio
|
||||
|
||||
# Time-mixing for channel
|
||||
self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
|
||||
self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
|
||||
|
||||
# FFN layers
|
||||
self.key = nn.Linear(d_model, self.hidden, bias=False)
|
||||
self.receptance = nn.Linear(d_model, d_model, bias=False)
|
||||
self.value = nn.Linear(self.hidden, d_model, bias=False)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
# Time-shift mixing
|
||||
xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
|
||||
xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
|
||||
|
||||
# Channel mixing
|
||||
k = self.key(xk)
|
||||
k = torch.square(torch.relu(k)) # Squared ReLU activation
|
||||
kv = self.value(k)
|
||||
|
||||
# Receptance gate
|
||||
r = torch.sigmoid(self.receptance(xr))
|
||||
|
||||
return r * kv
|
||||
```
|
||||
|
||||
## RWKV Block Structure
|
||||
|
||||
```python
|
||||
class RWKV_Block(nn.Module):
|
||||
def __init__(self, d_model, n_layer):
|
||||
super().__init__()
|
||||
self.ln1 = nn.LayerNorm(d_model)
|
||||
self.ln2 = nn.LayerNorm(d_model)
|
||||
self.att = RWKV_TimeMix(d_model, n_layer)
|
||||
self.ffn = RWKV_ChannelMix(d_model)
|
||||
|
||||
def forward(self, x, state):
|
||||
# Time-mixing with residual
|
||||
att_out, new_state = self.att(self.ln1(x), state)
|
||||
x = x + att_out
|
||||
|
||||
# Channel-mixing with residual
|
||||
ffn_out = self.ffn(self.ln2(x), state[:, :, 2]) # Use x_prev from state
|
||||
x = x + ffn_out
|
||||
|
||||
return x, new_state
|
||||
|
||||
# Full RWKV model
|
||||
model = nn.Sequential(
|
||||
Embedding(...),
|
||||
*[RWKV_Block(d_model, i) for i in range(n_layers)],
|
||||
LayerNorm(d_model),
|
||||
LMHead(...)
|
||||
)
|
||||
```
|
||||
|
||||
## Time-Decay Mechanism
|
||||
|
||||
The **time_decay** parameter `w` controls how fast information decays:
|
||||
|
||||
```python
|
||||
# Initialization (RWKV-4)
|
||||
time_decay = torch.ones(n_layers, d_model)
|
||||
for i in range(n_layers):
|
||||
for j in range(d_model):
|
||||
# Logarithmic spacing
|
||||
ratio = (i + 1) / n_layers
|
||||
time_decay[i, j] = -5.0 + 8.0 * ratio + 0.3 * (j / d_model)
|
||||
|
||||
# Effect on memory
|
||||
w = -exp(time_decay) # Range: [-exp(-5), -exp(3)] ≈ [-0.007, -20]
|
||||
# Smaller w = slower decay = longer memory
|
||||
# Larger w = faster decay = shorter memory
|
||||
```
|
||||
|
||||
**Layer-wise decay pattern**:
|
||||
- Early layers (shallow): Fast decay, capture local patterns
|
||||
- Later layers (deep): Slow decay, capture long-range dependencies
|
||||
|
||||
## Receptance Gate
|
||||
|
||||
The **receptance** mechanism controls information flow:
|
||||
|
||||
```python
|
||||
r = sigmoid(receptance(x)) # Range [0, 1]
|
||||
output = r * wkv # Gate the WKV output
|
||||
|
||||
# High receptance (r ≈ 1): Pass information through
|
||||
# Low receptance (r ≈ 0): Block information
|
||||
```
|
||||
|
||||
**Purpose**: Similar to LSTM forget gate, but learned per-token
|
||||
|
||||
## RWKV-4 vs RWKV-5 vs RWKV-6 vs RWKV-7
|
||||
|
||||
### RWKV-4 (Original)
|
||||
```python
|
||||
# Time-shift with previous token
|
||||
xx = x * time_mix + x_prev * (1 - time_mix)
|
||||
k, v, r = key(xx), value(xx), receptance(xx)
|
||||
```
|
||||
|
||||
### RWKV-5 (2023)
|
||||
```python
|
||||
# Separate time-mix for k, v, r
|
||||
xk = x * time_mix_k + x_prev * (1 - time_mix_k)
|
||||
xv = x * time_mix_v + x_prev * (1 - time_mix_v)
|
||||
xr = x * time_mix_r + x_prev * (1 - time_mix_r)
|
||||
k, v, r = key(xk), value(xk), receptance(xr)
|
||||
```
|
||||
|
||||
### RWKV-6 (2024)
|
||||
- Added **multi-head time-mixing** (like multi-head attention)
|
||||
- Separate time-decay per head
|
||||
- Improved stability for large models
|
||||
|
||||
```python
|
||||
# Per-head processing
|
||||
for h in range(n_heads):
|
||||
k_h = key[h](x) # Separate projection per head
|
||||
w_h = time_decay[h] # Separate decay per head
|
||||
wkv_h = wkv(k_h, v_h, w_h)
|
||||
output = concat(wkv_0, wkv_1, ..., wkv_H)
|
||||
```
|
||||
|
||||
### RWKV-7 (March 2025)
|
||||
- **Multimodal support** (vision + language)
|
||||
- Improved numerical stability
|
||||
- Better scaling to 14B+ parameters
|
||||
|
||||
## Numerical Stability
|
||||
|
||||
### Issue: Exponential Overflow
|
||||
|
||||
```python
|
||||
# Problem: exp(wkv) can overflow
|
||||
wkv = exp(u * kv) / exp(u * k) # Can overflow!
|
||||
```
|
||||
|
||||
### Solution: Log-space Computation
|
||||
|
||||
```python
|
||||
# Stable implementation
|
||||
log_wkv_num = u + log(kv) + log(aa)
|
||||
log_wkv_den = u + log(k) + log(ab)
|
||||
wkv = exp(log_wkv_num - log_wkv_den) # Numerically stable
|
||||
```
|
||||
|
||||
### Gradient Clipping
|
||||
|
||||
```python
|
||||
# Recommended for training stability
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
```
|
||||
|
||||
## State Management
|
||||
|
||||
### State Shape
|
||||
|
||||
```python
|
||||
# For batch inference
|
||||
state = torch.zeros(
|
||||
batch_size,
|
||||
n_layers,
|
||||
4, # (att_aa, att_ab, att_x_prev, ffn_x_prev)
|
||||
d_model
|
||||
)
|
||||
```
|
||||
|
||||
### State Initialization
|
||||
|
||||
```python
|
||||
# Zero initialization (standard)
|
||||
state = None # Model creates zero state
|
||||
|
||||
# Warm state (from previous conversation)
|
||||
_, state = model.forward(previous_context, None)
|
||||
# Use `state` for next turn
|
||||
```
|
||||
|
||||
### State Serialization
|
||||
|
||||
```python
|
||||
# Save conversation state
|
||||
torch.save(state, 'conversation_state.pt')
|
||||
|
||||
# Resume conversation
|
||||
state = torch.load('conversation_state.pt')
|
||||
out, state = model.forward(new_tokens, state)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper (RWKV): https://arxiv.org/abs/2305.13048 (May 2023)
|
||||
- Paper (RWKV-7): https://arxiv.org/abs/2503.14456 (March 2025)
|
||||
- GitHub: https://github.com/BlinkDL/RWKV-LM
|
||||
- Math derivation: https://wiki.rwkv.com/
|
||||
- CUDA kernels: https://github.com/BlinkDL/RWKV-CUDA
|
||||
@@ -0,0 +1,386 @@
|
||||
# RWKV-7: Latest Improvements (March 2025)
|
||||
|
||||
## Overview
|
||||
|
||||
RWKV-7 is the latest version released in March 2025, introducing multimodal capabilities and improved scaling to 14B+ parameters.
|
||||
|
||||
**Paper**: https://arxiv.org/abs/2503.14456 (March 2025)
|
||||
|
||||
## Key Improvements Over RWKV-6
|
||||
|
||||
### 1. Enhanced Numerical Stability
|
||||
|
||||
**Problem in RWKV-6**:
|
||||
```python
|
||||
# Exponential operations could overflow for large models
|
||||
att_aa = exp(w) * att_aa + k * v # Overflow risk!
|
||||
```
|
||||
|
||||
**RWKV-7 Solution**:
|
||||
```python
|
||||
# Log-space computation with safe exponentiation
|
||||
log_att_aa = log_softmax([log(k * v), log_w + log(att_aa)])
|
||||
att_aa = exp(log_att_aa)
|
||||
```
|
||||
|
||||
**Result**: Stable training up to 14B parameters (RWKV-6 struggled beyond 7B)
|
||||
|
||||
### 2. Improved Time-Decay Initialization
|
||||
|
||||
**RWKV-6**:
|
||||
```python
|
||||
# Simple logarithmic spacing
|
||||
time_decay[i] = -5.0 + 8.0 * (i / n_layers)
|
||||
```
|
||||
|
||||
**RWKV-7**:
|
||||
```python
|
||||
# Adaptive per-head decay with better range
|
||||
for layer in range(n_layers):
|
||||
for head in range(n_heads):
|
||||
# Different heads specialize in different timescales
|
||||
alpha = (layer / n_layers) ** 0.7 # Non-linear progression
|
||||
beta = (head / n_heads) * 0.5
|
||||
time_decay[layer, head] = -6.0 + 9.0 * alpha + beta
|
||||
|
||||
# Result: Better long/short-term memory balance
|
||||
```
|
||||
|
||||
**Impact**: 15-20% perplexity improvement on long-context tasks
|
||||
|
||||
### 3. Multi-Head Time-Mixing Refinements
|
||||
|
||||
**RWKV-6 Multi-Head**:
|
||||
```python
|
||||
# Simple concatenation
|
||||
heads = [head_i(x) for head_i in heads]
|
||||
output = concat(heads)
|
||||
```
|
||||
|
||||
**RWKV-7 Multi-Head**:
|
||||
```python
|
||||
# Attention-style output projection
|
||||
heads = [head_i(x) for head_i in heads]
|
||||
concat_heads = concat(heads)
|
||||
output = output_proj(concat_heads) # Learnable mixing
|
||||
|
||||
# Plus: Per-head layer norm
|
||||
for i, head in enumerate(heads):
|
||||
heads[i] = head_norm[i](head) # Separate norm per head
|
||||
```
|
||||
|
||||
**Result**: Better head specialization, 8-12% quality improvement
|
||||
|
||||
### 4. Rotary Position Encoding (RoPE) Integration
|
||||
|
||||
**New in RWKV-7**:
|
||||
```python
|
||||
class RWKV7_TimeMix(nn.Module):
|
||||
def __init__(self, d_model, n_heads):
|
||||
super().__init__()
|
||||
self.rope = RotaryEmbedding(d_model // n_heads)
|
||||
|
||||
def forward(self, x):
|
||||
k = self.key(x) # (B, T, d_model)
|
||||
v = self.value(x)
|
||||
|
||||
# Apply RoPE to keys
|
||||
k = self.rope.rotate_queries_or_keys(k)
|
||||
|
||||
# WKV with position-aware keys
|
||||
wkv = self.wkv(k, v)
|
||||
return wkv
|
||||
```
|
||||
|
||||
**Why useful**: Improves positional awareness without breaking O(n) complexity
|
||||
|
||||
### 5. RWKV-7 Block Structure
|
||||
|
||||
```python
|
||||
class RWKV7_Block(nn.Module):
|
||||
def __init__(self, d_model, n_heads):
|
||||
super().__init__()
|
||||
self.ln1 = nn.LayerNorm(d_model)
|
||||
self.ln2 = nn.LayerNorm(d_model)
|
||||
|
||||
# Multi-head time-mixing with RoPE
|
||||
self.att = RWKV7_MultiHeadTimeMix(d_model, n_heads)
|
||||
|
||||
# Enhanced channel-mixing
|
||||
self.ffn = RWKV7_ChannelMix(d_model, hidden_ratio=3.5) # Larger FFN
|
||||
|
||||
def forward(self, x, state):
|
||||
# Pre-norm (like GPT)
|
||||
att_out, new_state = self.att(self.ln1(x), state)
|
||||
x = x + att_out
|
||||
|
||||
# FFN with gating
|
||||
ffn_out = self.ffn(self.ln2(x))
|
||||
x = x + ffn_out
|
||||
|
||||
return x, new_state
|
||||
```
|
||||
|
||||
## Multimodal Capabilities
|
||||
|
||||
### Vision Encoder Integration
|
||||
|
||||
**Architecture**:
|
||||
```python
|
||||
class RWKV7_Multimodal(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Vision encoder (CLIP-style)
|
||||
self.vision_encoder = VisionTransformer(
|
||||
patch_size=14,
|
||||
d_model=1024,
|
||||
n_layers=24
|
||||
)
|
||||
|
||||
# Projection to RWKV space
|
||||
self.vision_proj = nn.Linear(1024, d_model)
|
||||
|
||||
# RWKV language model
|
||||
self.rwkv = RWKV7_LanguageModel(d_model=2560, n_layers=40)
|
||||
|
||||
def forward(self, image, text, state=None):
|
||||
# Encode image to patches
|
||||
vision_tokens = self.vision_encoder(image) # (B, 256, 1024)
|
||||
vision_tokens = self.vision_proj(vision_tokens) # (B, 256, 2560)
|
||||
|
||||
# Concatenate vision and text tokens
|
||||
combined = torch.cat([vision_tokens, text], dim=1)
|
||||
|
||||
# Process with RWKV
|
||||
out, state = self.rwkv(combined, state)
|
||||
|
||||
return out, state
|
||||
```
|
||||
|
||||
### Vision-Language Tasks
|
||||
|
||||
**Image Captioning**:
|
||||
```python
|
||||
model = RWKV7_Multimodal()
|
||||
|
||||
# Encode image
|
||||
image = load_image('cat.jpg')
|
||||
vision_tokens = model.vision_encoder(image)
|
||||
|
||||
# Generate caption
|
||||
state = None
|
||||
_, state = model.rwkv(vision_tokens, state) # Process image
|
||||
|
||||
# Autoregressive caption generation
|
||||
caption = []
|
||||
for _ in range(max_length):
|
||||
logits, state = model.rwkv(prev_token, state)
|
||||
next_token = sample(logits)
|
||||
caption.append(next_token)
|
||||
```
|
||||
|
||||
**VQA (Visual Question Answering)**:
|
||||
```python
|
||||
# Question: "What color is the cat?"
|
||||
question_tokens = tokenizer.encode("What color is the cat?")
|
||||
|
||||
# Process image + question
|
||||
combined = torch.cat([vision_tokens, question_tokens], dim=1)
|
||||
answer_logits, state = model.rwkv(combined, state)
|
||||
|
||||
# Answer: "orange"
|
||||
```
|
||||
|
||||
### Training Multimodal RWKV-7
|
||||
|
||||
```python
|
||||
# Pretrain vision encoder (CLIP-style)
|
||||
train_vision_encoder(image_text_pairs)
|
||||
|
||||
# Freeze vision encoder
|
||||
model.vision_encoder.requires_grad_(False)
|
||||
|
||||
# Train projection + RWKV
|
||||
for batch in multimodal_dataloader:
|
||||
images, captions = batch
|
||||
|
||||
# Forward
|
||||
vision_tokens = model.vision_encoder(images)
|
||||
vision_tokens = model.vision_proj(vision_tokens)
|
||||
|
||||
logits, _ = model.rwkv(
|
||||
torch.cat([vision_tokens, captions[:, :-1]], dim=1),
|
||||
state=None
|
||||
)
|
||||
|
||||
# Loss (next token prediction)
|
||||
loss = F.cross_entropy(
|
||||
logits[:, vision_tokens.shape[1]:].reshape(-1, vocab_size),
|
||||
captions.reshape(-1)
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
## Scaling to 14B Parameters
|
||||
|
||||
### Model Configuration
|
||||
|
||||
| Model | Layers | d_model | n_heads | Params | Context | VRAM (FP16) |
|
||||
|-------|--------|---------|---------|--------|---------|-------------|
|
||||
| RWKV-7-1.5B | 24 | 2048 | 16 | 1.5B | Infinite | 3 GB |
|
||||
| RWKV-7-3B | 32 | 2560 | 20 | 3B | Infinite | 6 GB |
|
||||
| RWKV-7-7B | 32 | 4096 | 32 | 7B | Infinite | 14 GB |
|
||||
| RWKV-7-14B | 40 | 5120 | 40 | 14B | Infinite | 28 GB |
|
||||
|
||||
### Training Efficiency Improvements
|
||||
|
||||
**RWKV-6 Training (7B)**:
|
||||
- Speed: 45K tokens/sec (8× A100)
|
||||
- Memory: 38 GB per GPU (4K sequence)
|
||||
- Stability: Occasional loss spikes
|
||||
|
||||
**RWKV-7 Training (14B)**:
|
||||
- Speed: 52K tokens/sec (8× A100) - **15% faster**
|
||||
- Memory: 42 GB per GPU (4K sequence) - **Better utilization**
|
||||
- Stability: No loss spikes - **Improved stability**
|
||||
|
||||
**Key optimization**: Fused CUDA kernels for multi-head WKV
|
||||
|
||||
### RWKV-7 vs GPT-3 (14B)
|
||||
|
||||
| Metric | RWKV-7-14B | GPT-3-13B | Advantage |
|
||||
|--------|------------|-----------|-----------|
|
||||
| Training Speed | 52K tok/s | 28K tok/s | 1.9× |
|
||||
| Inference (2K ctx) | 6,100 tok/s | 1,800 tok/s | 3.4× |
|
||||
| Inference (8K ctx) | 5,800 tok/s | 450 tok/s | **12.9×** |
|
||||
| Memory (inference) | 28 GB | 52 GB | 1.9× |
|
||||
| Perplexity (Pile) | 6.8 | 7.2 | +6% |
|
||||
|
||||
## Production Use Cases
|
||||
|
||||
### Microsoft Integration
|
||||
|
||||
**Windows Copilot** (Limited Release):
|
||||
- Uses RWKV-7-3B for on-device inference
|
||||
- 5-8× faster than GPT-2 with better quality
|
||||
- Constant memory for infinite context
|
||||
|
||||
**Office 365** (Experimental):
|
||||
- Document summarization with RWKV-7-7B
|
||||
- Handles 100K+ token documents efficiently
|
||||
- No KV cache storage needed
|
||||
|
||||
### NVIDIA NeMo Support
|
||||
|
||||
**NeMo Guardrails with RWKV-7**:
|
||||
```python
|
||||
from nemoguardrails import RailsConfig
|
||||
from nemoguardrails.llm.providers import register_llm_provider
|
||||
|
||||
# Register RWKV-7 as LLM backend
|
||||
register_llm_provider("rwkv7", RWKV7Provider)
|
||||
|
||||
config = RailsConfig.from_path("config/")
|
||||
rails = LLMRails(config, llm_provider="rwkv7")
|
||||
|
||||
# Use for content moderation
|
||||
response = rails.generate(user_input="...")
|
||||
```
|
||||
|
||||
## Benchmarks (RWKV-7 vs RWKV-6)
|
||||
|
||||
### Language Modeling
|
||||
|
||||
| Dataset | RWKV-6-7B | RWKV-7-7B | Improvement |
|
||||
|---------|-----------|-----------|-------------|
|
||||
| Pile (val) | 7.8 | 7.1 | +9% |
|
||||
| C4 | 9.3 | 8.6 | +8% |
|
||||
| WikiText-103 | 8.4 | 7.7 | +8% |
|
||||
| Lambada | 11.2 | 9.8 | +13% |
|
||||
|
||||
### Long-Context Tasks (32K context)
|
||||
|
||||
| Task | RWKV-6-7B | RWKV-7-7B | Improvement |
|
||||
|------|-----------|-----------|-------------|
|
||||
| QuALITY | 52.3 | 61.8 | +18% |
|
||||
| Qasper | 38.1 | 46.7 | +23% |
|
||||
| NarrativeQA | 41.2 | 49.5 | +20% |
|
||||
|
||||
**RWKV-7's improved time-decay** significantly helps long-context understanding
|
||||
|
||||
### Multimodal Benchmarks
|
||||
|
||||
| Task | RWKV-7-7B | LLaVA-7B | BLIP-2-7B |
|
||||
|------|-----------|----------|-----------|
|
||||
| VQAv2 | 74.2 | 78.5 | 82.1 |
|
||||
| GQA | 58.3 | 62.1 | 65.4 |
|
||||
| TextVQA | 51.2 | 58.2 | 60.8 |
|
||||
| COCO Caption | 118.3 | 125.7 | 132.4 |
|
||||
|
||||
**Note**: RWKV-7 competitive but not SOTA on vision (vision-focused models still better)
|
||||
|
||||
## Migration from RWKV-6 to RWKV-7
|
||||
|
||||
### Model Conversion
|
||||
|
||||
```python
|
||||
# Load RWKV-6 checkpoint
|
||||
rwkv6_state = torch.load('rwkv6-7b.pth')
|
||||
|
||||
# Initialize RWKV-7 model
|
||||
rwkv7_model = RWKV7_Model(d_model=4096, n_layers=32, n_heads=32)
|
||||
|
||||
# Convert weights (mostly compatible)
|
||||
for key in rwkv6_state:
|
||||
if 'time_mixing' in key:
|
||||
# RWKV-7 uses multi-head, need to split
|
||||
rwkv7_key = convert_key_to_multihead(key)
|
||||
rwkv7_model.state_dict()[rwkv7_key].copy_(rwkv6_state[key])
|
||||
else:
|
||||
# Direct copy
|
||||
rwkv7_model.state_dict()[key].copy_(rwkv6_state[key])
|
||||
|
||||
# Fine-tune on small dataset to adapt
|
||||
finetune(rwkv7_model, small_dataset, epochs=1)
|
||||
```
|
||||
|
||||
### State Compatibility
|
||||
|
||||
**RWKV-6 State**:
|
||||
```python
|
||||
state_v6 = (att_aa, att_ab, att_x_prev, ffn_x_prev) # 4 components
|
||||
```
|
||||
|
||||
**RWKV-7 State** (Multi-head):
|
||||
```python
|
||||
state_v7 = (
|
||||
att_aa_heads, # (n_heads, d_model//n_heads)
|
||||
att_ab_heads, # (n_heads, d_model//n_heads)
|
||||
att_x_prev,
|
||||
ffn_x_prev
|
||||
) # 4 components, but att_* are multi-head
|
||||
```
|
||||
|
||||
**Conversion**:
|
||||
```python
|
||||
# Split RWKV-6 state into RWKV-7 multi-head state
|
||||
def convert_state_v6_to_v7(state_v6, n_heads):
|
||||
att_aa, att_ab, att_x_prev, ffn_x_prev = state_v6
|
||||
d_head = att_aa.shape[-1] // n_heads
|
||||
|
||||
att_aa_heads = att_aa.view(-1, n_heads, d_head).transpose(0, 1)
|
||||
att_ab_heads = att_ab.view(-1, n_heads, d_head).transpose(0, 1)
|
||||
|
||||
return (att_aa_heads, att_ab_heads, att_x_prev, ffn_x_prev)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2503.14456 (RWKV-7, March 2025)
|
||||
- **GitHub**: https://github.com/BlinkDL/RWKV-LM (v7 branch)
|
||||
- **Models**: https://huggingface.co/BlinkDL/rwkv-7-world
|
||||
- **Multimodal Demo**: https://huggingface.co/spaces/BlinkDL/RWKV-7-Multimodal
|
||||
- **Discord**: https://discord.gg/bDSBUMeFpc
|
||||
- **Wiki**: https://wiki.rwkv.com/rwkv7
|
||||
@@ -0,0 +1,369 @@
|
||||
# RWKV State Management
|
||||
|
||||
## Understanding RWKV State
|
||||
|
||||
Unlike Transformers with KV cache, RWKV maintains a **fixed-size recurrent state** that summarizes all previous context.
|
||||
|
||||
### State Components
|
||||
|
||||
```python
|
||||
state = {
|
||||
'att_aa': torch.zeros(n_layers, d_model), # Attention numerator accumulator
|
||||
'att_ab': torch.zeros(n_layers, d_model), # Attention denominator accumulator
|
||||
'att_x_prev': torch.zeros(n_layers, d_model), # Previous x for time-mixing
|
||||
'ffn_x_prev': torch.zeros(n_layers, d_model) # Previous x for channel-mixing
|
||||
}
|
||||
```
|
||||
|
||||
**Total state size**: `4 × n_layers × d_model` parameters
|
||||
|
||||
| Model | Layers | d_model | State Size |
|
||||
|-------|--------|---------|------------|
|
||||
| RWKV-169M | 12 | 768 | 37 KB |
|
||||
| RWKV-430M | 24 | 1024 | 98 KB |
|
||||
| RWKV-1.5B | 24 | 2048 | 196 KB |
|
||||
| RWKV-3B | 32 | 2560 | 327 KB |
|
||||
| RWKV-7B | 32 | 4096 | 524 KB |
|
||||
| RWKV-14B | 40 | 5120 | 819 KB |
|
||||
|
||||
**Constant memory** regardless of context length!
|
||||
|
||||
## State Initialization
|
||||
|
||||
### Zero State (Default)
|
||||
|
||||
```python
|
||||
from rwkv.model import RWKV
|
||||
|
||||
model = RWKV(model='/path/to/RWKV-4-Pile-1B5', strategy='cuda fp16')
|
||||
|
||||
# Start with zero state (no context)
|
||||
state = None
|
||||
out, state = model.forward(tokens, state)
|
||||
```
|
||||
|
||||
### Warm State (Preloaded Context)
|
||||
|
||||
```python
|
||||
# Load context once
|
||||
context = "The capital of France is Paris. The capital of Germany is Berlin."
|
||||
context_tokens = tokenizer.encode(context)
|
||||
|
||||
# Process context to build state
|
||||
state = None
|
||||
for token in context_tokens:
|
||||
_, state = model.forward([token], state)
|
||||
|
||||
# Now use warm state for queries
|
||||
query = " The capital of Italy is"
|
||||
query_tokens = tokenizer.encode(query)
|
||||
out, state = model.forward(query_tokens, state)
|
||||
# Model "remembers" Paris and Berlin examples!
|
||||
```
|
||||
|
||||
### Shared State (Multi-turn Conversations)
|
||||
|
||||
```python
|
||||
# Conversation with persistent state
|
||||
state = None
|
||||
|
||||
# Turn 1
|
||||
user1 = "My name is Alice."
|
||||
tokens1 = tokenizer.encode(user1)
|
||||
_, state = model.forward(tokens1, state)
|
||||
|
||||
# Turn 2
|
||||
user2 = "What is my name?"
|
||||
tokens2 = tokenizer.encode(user2)
|
||||
response, state = model.forward(tokens2, state)
|
||||
# Response: "Alice" (state remembers!)
|
||||
```
|
||||
|
||||
## State Update Rules
|
||||
|
||||
### Time-Mixing State Update
|
||||
|
||||
```python
|
||||
# Before processing token t
|
||||
att_aa_t = att_aa_{t-1} # Previous numerator
|
||||
att_ab_t = att_ab_{t-1} # Previous denominator
|
||||
|
||||
# Compute WKV
|
||||
wkv_t = (exp(u) * k_t * v_t + att_aa_t) / (exp(u) * k_t + att_ab_t)
|
||||
|
||||
# Update state for token t+1
|
||||
w = -exp(time_decay) # Decay factor
|
||||
att_aa_{t+1} = exp(w) * att_aa_t + k_t * v_t
|
||||
att_ab_{t+1} = exp(w) * att_ab_t + k_t
|
||||
att_x_prev_{t+1} = x_t
|
||||
```
|
||||
|
||||
**Effect of time_decay**:
|
||||
- **w = -0.01** (small decay): State decays slowly → long memory
|
||||
- **w = -5.0** (large decay): State decays quickly → short memory
|
||||
|
||||
### Channel-Mixing State Update
|
||||
|
||||
```python
|
||||
# Simply store previous x for next token
|
||||
ffn_x_prev_{t+1} = x_t
|
||||
```
|
||||
|
||||
## State Serialization
|
||||
|
||||
### Save/Load State (PyTorch)
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Save conversation state
|
||||
state_dict = {
|
||||
'att_aa': state[0],
|
||||
'att_ab': state[1],
|
||||
'att_x_prev': state[2],
|
||||
'ffn_x_prev': state[3]
|
||||
}
|
||||
torch.save(state_dict, 'conversation_123.pt')
|
||||
|
||||
# Load state
|
||||
loaded = torch.load('conversation_123.pt')
|
||||
state = (loaded['att_aa'], loaded['att_ab'], loaded['att_x_prev'], loaded['ffn_x_prev'])
|
||||
|
||||
# Continue conversation
|
||||
out, state = model.forward(new_tokens, state)
|
||||
```
|
||||
|
||||
### State Compression (Optional)
|
||||
|
||||
```python
|
||||
# FP16 state (half size)
|
||||
state_fp16 = tuple(s.half() for s in state)
|
||||
torch.save(state_fp16, 'state_compressed.pt')
|
||||
|
||||
# Restore
|
||||
state = tuple(s.float() for s in torch.load('state_compressed.pt'))
|
||||
```
|
||||
|
||||
## Multi-Session State Management
|
||||
|
||||
### Session State Store
|
||||
|
||||
```python
|
||||
class StateManager:
|
||||
def __init__(self):
|
||||
self.sessions = {} # session_id -> state
|
||||
|
||||
def get_state(self, session_id):
|
||||
return self.sessions.get(session_id, None)
|
||||
|
||||
def save_state(self, session_id, state):
|
||||
self.sessions[session_id] = state
|
||||
|
||||
def clear_session(self, session_id):
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
|
||||
# Usage
|
||||
manager = StateManager()
|
||||
|
||||
# User 1 conversation
|
||||
state1 = manager.get_state('user_1')
|
||||
out1, state1 = model.forward(tokens1, state1)
|
||||
manager.save_state('user_1', state1)
|
||||
|
||||
# User 2 conversation (independent state)
|
||||
state2 = manager.get_state('user_2')
|
||||
out2, state2 = model.forward(tokens2, state2)
|
||||
manager.save_state('user_2', state2)
|
||||
```
|
||||
|
||||
### State Expiration
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
class StateManagerWithExpiry:
|
||||
def __init__(self, expiry_seconds=3600):
|
||||
self.sessions = {} # session_id -> (state, timestamp)
|
||||
self.expiry = expiry_seconds
|
||||
|
||||
def get_state(self, session_id):
|
||||
if session_id in self.sessions:
|
||||
state, timestamp = self.sessions[session_id]
|
||||
if time.time() - timestamp < self.expiry:
|
||||
return state
|
||||
else:
|
||||
del self.sessions[session_id] # Expired
|
||||
return None
|
||||
|
||||
def save_state(self, session_id, state):
|
||||
self.sessions[session_id] = (state, time.time())
|
||||
```
|
||||
|
||||
## State Interpolation
|
||||
|
||||
### Blending States
|
||||
|
||||
```python
|
||||
# Average two states (e.g., merging conversations)
|
||||
def blend_states(state1, state2, alpha=0.5):
|
||||
"""Blend state1 and state2 with weight alpha."""
|
||||
return tuple(
|
||||
alpha * s1 + (1 - alpha) * s2
|
||||
for s1, s2 in zip(state1, state2)
|
||||
)
|
||||
|
||||
# Example: Blend Alice and Bob conversation contexts
|
||||
state_blended = blend_states(state_alice, state_bob, alpha=0.7)
|
||||
# 70% Alice context, 30% Bob context
|
||||
```
|
||||
|
||||
### State Editing
|
||||
|
||||
```python
|
||||
# Manually edit state (advanced)
|
||||
# Example: Reduce long-term memory influence
|
||||
|
||||
def decay_state(state, decay_factor=0.5):
|
||||
"""Reduce state magnitude (forget older context)."""
|
||||
att_aa, att_ab, att_x_prev, ffn_x_prev = state
|
||||
return (
|
||||
att_aa * decay_factor,
|
||||
att_ab * decay_factor,
|
||||
att_x_prev, # Keep recent x
|
||||
ffn_x_prev # Keep recent x
|
||||
)
|
||||
|
||||
# Usage
|
||||
state = decay_state(state, decay_factor=0.3) # Forget 70% of history
|
||||
```
|
||||
|
||||
## Batch Inference with States
|
||||
|
||||
### Independent Batch States
|
||||
|
||||
```python
|
||||
# Each sequence in batch has separate state
|
||||
batch_size = 4
|
||||
states = [None] * batch_size
|
||||
|
||||
for i, tokens in enumerate(batch_sequences):
|
||||
out, states[i] = model.forward(tokens, states[i])
|
||||
```
|
||||
|
||||
### Shared Prefix Optimization
|
||||
|
||||
```python
|
||||
# All sequences share common prefix (e.g., system prompt)
|
||||
prefix = "You are a helpful assistant."
|
||||
prefix_tokens = tokenizer.encode(prefix)
|
||||
|
||||
# Compute prefix state once
|
||||
prefix_state = None
|
||||
_, prefix_state = model.forward(prefix_tokens, None)
|
||||
|
||||
# Clone prefix state for each sequence
|
||||
states = [prefix_state] * batch_size
|
||||
|
||||
# Process user queries (independent)
|
||||
for i, user_query in enumerate(user_queries):
|
||||
tokens = tokenizer.encode(user_query)
|
||||
out, states[i] = model.forward(tokens, states[i])
|
||||
```
|
||||
|
||||
## State Debugging
|
||||
|
||||
### Inspect State Magnitudes
|
||||
|
||||
```python
|
||||
def inspect_state(state):
|
||||
"""Print state statistics for debugging."""
|
||||
att_aa, att_ab, att_x_prev, ffn_x_prev = state
|
||||
|
||||
print("State magnitudes:")
|
||||
print(f" att_aa: mean={att_aa.abs().mean():.4f}, max={att_aa.abs().max():.4f}")
|
||||
print(f" att_ab: mean={att_ab.abs().mean():.4f}, max={att_ab.abs().max():.4f}")
|
||||
print(f" att_x_prev: mean={att_x_prev.abs().mean():.4f}, max={att_x_prev.abs().max():.4f}")
|
||||
print(f" ffn_x_prev: mean={ffn_x_prev.abs().mean():.4f}, max={ffn_x_prev.abs().max():.4f}")
|
||||
|
||||
# Usage
|
||||
inspect_state(state)
|
||||
```
|
||||
|
||||
**Healthy ranges**:
|
||||
- `att_aa`, `att_ab`: 0.1 - 10.0 (if much larger, may overflow)
|
||||
- `att_x_prev`, `ffn_x_prev`: Similar to input embedding magnitude
|
||||
|
||||
### State Divergence Check
|
||||
|
||||
```python
|
||||
def state_distance(state1, state2):
|
||||
"""Compute L2 distance between two states."""
|
||||
return sum(
|
||||
torch.dist(s1, s2).item()
|
||||
for s1, s2 in zip(state1, state2)
|
||||
)
|
||||
|
||||
# Example: Check if states diverged
|
||||
distance = state_distance(state_alice, state_bob)
|
||||
print(f"State distance: {distance:.2f}")
|
||||
# Large distance → very different contexts
|
||||
```
|
||||
|
||||
## Numerical Stability Considerations
|
||||
|
||||
### Overflow Prevention
|
||||
|
||||
```python
|
||||
# Issue: att_aa, att_ab can grow unbounded
|
||||
# If att_aa > 1e10, numerical precision issues
|
||||
|
||||
# Solution 1: Periodic normalization
|
||||
if att_aa.abs().max() > 1e6:
|
||||
scale = att_aa.abs().max()
|
||||
att_aa = att_aa / scale
|
||||
att_ab = att_ab / scale
|
||||
```
|
||||
|
||||
### Underflow Prevention
|
||||
|
||||
```python
|
||||
# Issue: With large negative time_decay, state can underflow to 0
|
||||
|
||||
# Solution: Clip time_decay
|
||||
time_decay = torch.clamp(time_decay, min=-8.0, max=-0.1)
|
||||
# Ensures state doesn't decay too fast
|
||||
```
|
||||
|
||||
## State vs KV Cache Comparison
|
||||
|
||||
### Memory Usage (8K context)
|
||||
|
||||
| Model Type | Model Size | KV Cache Size | RWKV State Size |
|
||||
|------------|------------|---------------|-----------------|
|
||||
| Transformer | 1.3B | 4.1 GB | - |
|
||||
| **RWKV** | **1.5B** | **-** | **196 KB** |
|
||||
| Transformer | 7B | 21.3 GB | - |
|
||||
| **RWKV** | **7B** | **-** | **524 KB** |
|
||||
|
||||
**RWKV advantage**: 10,000× smaller than KV cache!
|
||||
|
||||
### Information Retention
|
||||
|
||||
**KV Cache (Transformer)**:
|
||||
- Perfect: Stores all previous keys and values
|
||||
- Retrieval: Exact attention to any previous token
|
||||
- Cost: O(n) memory growth
|
||||
|
||||
**RWKV State**:
|
||||
- Lossy: Compressed representation of history
|
||||
- Retrieval: Weighted blend of previous tokens (decay-based)
|
||||
- Cost: O(1) constant memory
|
||||
|
||||
**Trade-off**: RWKV sacrifices perfect recall for constant memory
|
||||
|
||||
## Resources
|
||||
|
||||
- State management examples: https://github.com/BlinkDL/ChatRWKV
|
||||
- Wiki: https://wiki.rwkv.com/state-management
|
||||
- Discord: https://discord.gg/bDSBUMeFpc (RWKV community)
|
||||
@@ -0,0 +1,358 @@
|
||||
---
|
||||
name: distributed-llm-pretraining-torchtitan
|
||||
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
|
||||
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
|
||||
---
|
||||
|
||||
# TorchTitan - PyTorch Native Distributed LLM Pretraining
|
||||
|
||||
## Quick start
|
||||
|
||||
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# From PyPI (stable)
|
||||
pip install torchtitan
|
||||
|
||||
# From source (latest features, requires PyTorch nightly)
|
||||
git clone https://github.com/pytorch/torchtitan
|
||||
cd torchtitan
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**Download tokenizer**:
|
||||
```bash
|
||||
# Get HF token from https://huggingface.co/settings/tokens
|
||||
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
|
||||
```
|
||||
|
||||
**Start training on 8 GPUs**:
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Pretrain Llama 3.1 8B on single node
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Single Node Pretraining:
|
||||
- [ ] Step 1: Download tokenizer
|
||||
- [ ] Step 2: Configure training
|
||||
- [ ] Step 3: Launch training
|
||||
- [ ] Step 4: Monitor and checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Download tokenizer**
|
||||
|
||||
```bash
|
||||
python scripts/download_hf_assets.py \
|
||||
--repo_id meta-llama/Llama-3.1-8B \
|
||||
--assets tokenizer \
|
||||
--hf_token=YOUR_HF_TOKEN
|
||||
```
|
||||
|
||||
**Step 2: Configure training**
|
||||
|
||||
Edit or create a TOML config file:
|
||||
|
||||
```toml
|
||||
# llama3_8b_custom.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Llama 3.1 8B training"
|
||||
|
||||
[model]
|
||||
name = "llama3"
|
||||
flavor = "8B"
|
||||
hf_assets_path = "./assets/hf/Llama-3.1-8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[lr_scheduler]
|
||||
warmup_steps = 200
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
max_norm = 1.0
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
|
||||
|
||||
[activation_checkpoint]
|
||||
mode = "selective"
|
||||
selective_ac_option = "op"
|
||||
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
**Step 3: Launch training**
|
||||
|
||||
```bash
|
||||
# 8 GPUs on single node
|
||||
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
|
||||
|
||||
# Or explicitly with torchrun
|
||||
torchrun --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_8b_custom.toml
|
||||
```
|
||||
|
||||
**Step 4: Monitor and checkpoint**
|
||||
|
||||
TensorBoard logs are saved to `./outputs/tb/`:
|
||||
```bash
|
||||
tensorboard --logdir ./outputs/tb
|
||||
```
|
||||
|
||||
### Workflow 2: Multi-node training with SLURM
|
||||
|
||||
```
|
||||
Multi-Node Training:
|
||||
- [ ] Step 1: Configure parallelism for scale
|
||||
- [ ] Step 2: Set up SLURM script
|
||||
- [ ] Step 3: Submit job
|
||||
- [ ] Step 4: Resume from checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Configure parallelism for scale**
|
||||
|
||||
For 70B model on 256 GPUs (32 nodes):
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 32 # FSDP across 32 ranks
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 1 # No PP for 70B
|
||||
context_parallel_degree = 1 # Increase for long sequences
|
||||
```
|
||||
|
||||
**Step 2: Set up SLURM script**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llama70b
|
||||
#SBATCH --nodes=32
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
|
||||
srun torchrun \
|
||||
--nnodes=32 \
|
||||
--nproc_per_node=8 \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_70b.toml
|
||||
```
|
||||
|
||||
**Step 3: Submit job**
|
||||
|
||||
```bash
|
||||
sbatch multinode_trainer.slurm
|
||||
```
|
||||
|
||||
**Step 4: Resume from checkpoint**
|
||||
|
||||
Training auto-resumes if checkpoint exists in configured folder.
|
||||
|
||||
### Workflow 3: Enable Float8 training for H100s
|
||||
|
||||
Float8 provides 30-50% speedup on H100 GPUs.
|
||||
|
||||
```
|
||||
Float8 Training:
|
||||
- [ ] Step 1: Install torchao
|
||||
- [ ] Step 2: Configure Float8
|
||||
- [ ] Step 3: Launch with compile
|
||||
```
|
||||
|
||||
**Step 1: Install torchao**
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
**Step 2: Configure Float8**
|
||||
|
||||
Add to your TOML config:
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output"] # Exclude output layer
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
**Step 3: Launch with compile**
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Workflow 4: 4D parallelism for 405B models
|
||||
|
||||
```
|
||||
4D Parallelism (FSDP + TP + PP + CP):
|
||||
- [ ] Step 1: Create seed checkpoint
|
||||
- [ ] Step 2: Configure 4D parallelism
|
||||
- [ ] Step 3: Launch on 512 GPUs
|
||||
```
|
||||
|
||||
**Step 1: Create seed checkpoint**
|
||||
|
||||
Required for consistent initialization across PP stages:
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1
|
||||
```
|
||||
|
||||
**Step 2: Configure 4D parallelism**
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 8 # FSDP
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 8 # PP across nodes
|
||||
context_parallel_degree = 1 # CP for long sequences
|
||||
|
||||
[training]
|
||||
local_batch_size = 32
|
||||
seq_len = 8192
|
||||
```
|
||||
|
||||
**Step 3: Launch on 512 GPUs**
|
||||
|
||||
```bash
|
||||
# 64 nodes x 8 GPUs = 512 GPUs
|
||||
srun torchrun --nnodes=64 --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_405b.toml
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TorchTitan when:**
|
||||
- Pretraining LLMs from scratch (8B to 405B+)
|
||||
- Need PyTorch-native solution without third-party dependencies
|
||||
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
|
||||
- Training on H100s with Float8 support
|
||||
- Want interoperable checkpoints with torchtune/HuggingFace
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
|
||||
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
|
||||
- **Axolotl/TRL**: Fine-tuning rather than pretraining
|
||||
- **LitGPT**: Educational, smaller-scale training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory on large models**
|
||||
|
||||
Enable activation checkpointing and reduce batch size:
|
||||
```toml
|
||||
[activation_checkpoint]
|
||||
mode = "full" # Instead of "selective"
|
||||
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
```
|
||||
|
||||
Or use gradient accumulation:
|
||||
```toml
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
global_batch_size = 32 # Accumulates gradients
|
||||
```
|
||||
|
||||
**Issue: TP causes high memory with async collectives**
|
||||
|
||||
Set environment variable:
|
||||
```bash
|
||||
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
|
||||
```
|
||||
|
||||
**Issue: Float8 training not faster**
|
||||
|
||||
Float8 only benefits large GEMMs. Filter small layers:
|
||||
```toml
|
||||
[quantize.linear.float8]
|
||||
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
|
||||
```
|
||||
|
||||
**Issue: Checkpoint loading fails after parallelism change**
|
||||
|
||||
Use DCP's resharding capability:
|
||||
```bash
|
||||
# Convert sharded checkpoint to single file
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch checkpoint/step-1000 checkpoint.pt
|
||||
```
|
||||
|
||||
**Issue: Pipeline parallelism initialization**
|
||||
|
||||
Create seed checkpoint first (see Workflow 4, Step 1).
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Sizes | Status |
|
||||
|-------|-------|--------|
|
||||
| Llama 3.1 | 8B, 70B, 405B | Production |
|
||||
| Llama 4 | Various | Experimental |
|
||||
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
|
||||
| GPT-OSS | 20B, 120B (MoE) | Experimental |
|
||||
| Qwen 3 | Various | Experimental |
|
||||
| Flux | Diffusion | Experimental |
|
||||
|
||||
## Performance benchmarks (H100)
|
||||
|
||||
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|
||||
|-------|------|-------------|---------|------------|
|
||||
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
|
||||
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
|
||||
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
|
||||
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
|
||||
|
||||
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
|
||||
|
||||
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
|
||||
|
||||
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/pytorch/torchtitan
|
||||
- Paper: https://arxiv.org/abs/2410.06511
|
||||
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
|
||||
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
# Checkpointing in TorchTitan
|
||||
|
||||
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
|
||||
|
||||
## Basic Configuration
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
## Save Model Only (Smaller Checkpoints)
|
||||
|
||||
Exclude optimizer state and training metadata:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16" # Optional: export in lower precision
|
||||
```
|
||||
|
||||
## Excluding Keys from Loading
|
||||
|
||||
Partial checkpoint loading for modified settings:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
exclude_from_loading = ["data_loader", "lr_scheduler"]
|
||||
```
|
||||
|
||||
CLI equivalent:
|
||||
```bash
|
||||
--checkpoint.exclude_from_loading data_loader,lr_scheduler
|
||||
```
|
||||
|
||||
## Creating Seed Checkpoints
|
||||
|
||||
Required for Pipeline Parallelism to ensure consistent initialization:
|
||||
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_replicate_degree 1 \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1 \
|
||||
--parallelism.context_parallel_degree 1 \
|
||||
--parallelism.expert_parallel_degree 1
|
||||
```
|
||||
|
||||
This initializes on single CPU for reproducible initialization across any GPU count.
|
||||
|
||||
## Async Checkpointing
|
||||
|
||||
Reduce checkpoint overhead with async writes:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
|
||||
```
|
||||
|
||||
## HuggingFace Conversion
|
||||
|
||||
### During Training
|
||||
|
||||
Save directly in HuggingFace format:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
last_save_in_hf = true
|
||||
last_save_model_only = true
|
||||
```
|
||||
|
||||
Load from HuggingFace:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
initial_load_in_hf = true
|
||||
|
||||
[model]
|
||||
hf_assets_path = "./path/to/hf/checkpoint"
|
||||
```
|
||||
|
||||
### Offline Conversion
|
||||
|
||||
Convert without running training:
|
||||
|
||||
```bash
|
||||
# HuggingFace -> TorchTitan
|
||||
python ./scripts/checkpoint_conversion/convert_from_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
|
||||
# TorchTitan -> HuggingFace
|
||||
python ./scripts/checkpoint_conversion/convert_to_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--hf_assets_path ./assets/hf/Llama3.1-8B \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
python ./scripts/convert_from_hf.py \
|
||||
~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
|
||||
./initial_load_path/ \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
## Converting to Single .pt File
|
||||
|
||||
Convert DCP sharded checkpoint to single PyTorch file:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch \
|
||||
torchtitan/outputs/checkpoint/step-1000 \
|
||||
checkpoint.pt
|
||||
```
|
||||
|
||||
## Checkpoint Structure
|
||||
|
||||
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
|
||||
|
||||
```
|
||||
checkpoint/
|
||||
├── step-500/
|
||||
│ ├── .metadata
|
||||
│ ├── __0_0.distcp
|
||||
│ ├── __0_1.distcp
|
||||
│ └── ...
|
||||
└── step-1000/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Resume Training
|
||||
|
||||
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
load_step = 500 # Resume from step 500
|
||||
```
|
||||
|
||||
## Interoperability with TorchTune
|
||||
|
||||
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
load_step = -1 # -1 = latest, or specify step number
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16"
|
||||
async_mode = "async"
|
||||
exclude_from_loading = []
|
||||
last_save_in_hf = false
|
||||
initial_load_in_hf = false
|
||||
create_seed_checkpoint = false
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
|
||||
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
|
||||
3. **Pipeline parallelism**: Always create seed checkpoint first
|
||||
4. **Debugging**: Save frequent checkpoints during development, reduce for production
|
||||
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows
|
||||
@@ -0,0 +1,258 @@
|
||||
# Adding Custom Models to TorchTitan
|
||||
|
||||
This guide explains how to add a new model to TorchTitan following the established patterns.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
torchtitan/models/your_model/
|
||||
├── model/
|
||||
│ ├── __init__.py
|
||||
│ ├── args.py # Model arguments
|
||||
│ ├── model.py # Model definition
|
||||
│ └── state_dict_adapter.py # HF conversion (optional)
|
||||
├── infra/
|
||||
│ ├── __init__.py
|
||||
│ ├── parallelize.py # TP, FSDP, compile application
|
||||
│ └── pipeline.py # PP application (optional)
|
||||
├── train_configs/
|
||||
│ ├── debug_model.toml
|
||||
│ └── your_model_XB.toml
|
||||
├── __init__.py # TrainSpec registration
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Step 1: Define Model Arguments
|
||||
|
||||
Inherit from `BaseModelArgs`:
|
||||
|
||||
```python
|
||||
# model/args.py
|
||||
from torchtitan.protocols.model import BaseModelArgs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class YourModelArgs(BaseModelArgs):
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
vocab_size: int = 128256
|
||||
|
||||
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
|
||||
"""Return (num_params, flops_per_token) for throughput calculation."""
|
||||
nparams = self.vocab_size * self.dim + ... # Calculate params
|
||||
flops = 6 * nparams # Approximate: 6 * params for forward+backward
|
||||
return nparams, flops
|
||||
|
||||
def update_from_config(self, job_config) -> "YourModelArgs":
|
||||
"""Update args from training config."""
|
||||
# Override specific args from job_config if needed
|
||||
return self
|
||||
```
|
||||
|
||||
## Step 2: Define Model
|
||||
|
||||
Inherit from `ModelProtocol`:
|
||||
|
||||
```python
|
||||
# model/model.py
|
||||
import torch.nn as nn
|
||||
from torchtitan.protocols.model import ModelProtocol
|
||||
from .args import YourModelArgs
|
||||
|
||||
class YourModel(ModelProtocol):
|
||||
def __init__(self, args: YourModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = nn.ModuleDict({
|
||||
str(i): TransformerBlock(args) for i in range(args.n_layers)
|
||||
})
|
||||
self.norm = RMSNorm(args.dim)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers.values():
|
||||
h = layer(h)
|
||||
h = self.norm(h)
|
||||
return self.output(h)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights recursively."""
|
||||
for module in self.modules():
|
||||
if hasattr(module, 'init_weights') and module is not self:
|
||||
module.init_weights()
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=0.02)
|
||||
```
|
||||
|
||||
**Important guidelines**:
|
||||
- Write single-device model code (parallelism applied externally)
|
||||
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
|
||||
- Make input/output layers optional for PP compatibility
|
||||
- Define `init_weights()` recursively
|
||||
|
||||
## Step 3: Parallelize Function
|
||||
|
||||
```python
|
||||
# infra/parallelize.py
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
|
||||
def parallelize_your_model(
|
||||
model: YourModel,
|
||||
world_mesh: DeviceMesh,
|
||||
parallel_dims: ParallelDims,
|
||||
job_config: JobConfig,
|
||||
):
|
||||
# Apply in this order: TP -> AC -> compile -> FSDP
|
||||
|
||||
# 1. Tensor Parallelism
|
||||
if parallel_dims.tp_enabled:
|
||||
apply_tp(model, world_mesh["tp"], job_config)
|
||||
|
||||
# 2. Activation Checkpointing
|
||||
if job_config.activation_checkpoint.mode == "full":
|
||||
apply_ac(model, job_config)
|
||||
|
||||
# 3. torch.compile
|
||||
if job_config.compile.enable:
|
||||
model = torch.compile(model)
|
||||
|
||||
# 4. FSDP
|
||||
if parallel_dims.dp_enabled:
|
||||
apply_fsdp(model, world_mesh["dp"], job_config)
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
## Step 4: Create TrainSpec
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
|
||||
from .model.model import YourModel
|
||||
from .model.args import YourModelArgs
|
||||
from .infra.parallelize import parallelize_your_model
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
|
||||
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
|
||||
}
|
||||
|
||||
def get_train_spec(flavor: str) -> TrainSpec:
|
||||
return TrainSpec(
|
||||
model_cls=YourModel,
|
||||
model_args=MODEL_CONFIGS[flavor],
|
||||
parallelize_fn=parallelize_your_model,
|
||||
pipeline_fn=None, # Or your_pipeline_fn for PP
|
||||
build_optimizer_fn=build_optimizer, # Reuse existing
|
||||
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
|
||||
build_dataloader_fn=build_dataloader, # Reuse existing
|
||||
build_tokenizer_fn=build_tokenizer, # Reuse existing
|
||||
build_loss_fn=build_loss, # Reuse existing
|
||||
state_dict_adapter=None, # Or YourStateDictAdapter
|
||||
)
|
||||
|
||||
# Register so train.py can find it
|
||||
register_train_spec("your_model", get_train_spec)
|
||||
```
|
||||
|
||||
## Step 5: State Dict Adapter (Optional)
|
||||
|
||||
For HuggingFace checkpoint conversion:
|
||||
|
||||
```python
|
||||
# model/state_dict_adapter.py
|
||||
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
|
||||
|
||||
class YourStateDictAdapter(BaseStateDictAdapter):
|
||||
def to_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert torchtitan state dict to HF format."""
|
||||
hf_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
hf_key = self._convert_key_to_hf(key)
|
||||
hf_state_dict[hf_key] = value
|
||||
return hf_state_dict
|
||||
|
||||
def from_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert HF state dict to torchtitan format."""
|
||||
tt_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
tt_key = self._convert_key_from_hf(key)
|
||||
tt_state_dict[tt_key] = value
|
||||
return tt_state_dict
|
||||
```
|
||||
|
||||
## Step 6: Training Config
|
||||
|
||||
```toml
|
||||
# train_configs/your_model_8b.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Your Model 8B training"
|
||||
|
||||
[model]
|
||||
name = "your_model"
|
||||
flavor = "8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1
|
||||
tensor_parallel_degree = 1
|
||||
```
|
||||
|
||||
## Step 7: Register Model
|
||||
|
||||
Add to `torchtitan/models/__init__.py`:
|
||||
|
||||
```python
|
||||
from .your_model import get_train_spec as get_your_model_train_spec
|
||||
|
||||
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Numerics Test
|
||||
|
||||
Compare output with HuggingFace implementation:
|
||||
|
||||
```python
|
||||
def test_numerics():
|
||||
# Load same checkpoint into both implementations
|
||||
tt_model = YourModel(args).load_checkpoint(...)
|
||||
hf_model = HFYourModel.from_pretrained(...)
|
||||
|
||||
# Compare outputs
|
||||
input_ids = torch.randint(0, vocab_size, (1, 128))
|
||||
tt_output = tt_model(input_ids)
|
||||
hf_output = hf_model(input_ids).logits
|
||||
|
||||
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
|
||||
```
|
||||
|
||||
### Loss Convergence
|
||||
|
||||
Compare loss curves with verified baseline (see `docs/converging.md`).
|
||||
|
||||
### Performance Benchmark
|
||||
|
||||
Add benchmark config to `benchmarks/` folder.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
1. **Readability over flexibility**: Don't over-abstract
|
||||
2. **Minimal model changes**: Parallelism applied externally
|
||||
3. **Clean, minimal codebase**: Reuse existing components where possible
|
||||
4. **Single-device semantics**: Model code should work on single GPU
|
||||
@@ -0,0 +1,133 @@
|
||||
# Float8 Training in TorchTitan
|
||||
|
||||
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
|
||||
- Blackwell GPUs for MXFP8 training
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
## Usage: Tensorwise Scaling
|
||||
|
||||
Standard Float8 with tensorwise dynamic scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Key Arguments
|
||||
|
||||
| Argument | Description |
|
||||
|----------|-------------|
|
||||
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
|
||||
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
|
||||
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
|
||||
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
|
||||
|
||||
## Usage: Rowwise Scaling
|
||||
|
||||
Higher accuracy than tensorwise scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.recipe_name rowwise \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
## Filtering Layers
|
||||
|
||||
Not all layers benefit from Float8. Filter small layers:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
|
||||
```
|
||||
|
||||
### Auto-filtering
|
||||
|
||||
Automatically skip layers too small to benefit:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
|
||||
```
|
||||
|
||||
Thresholds based on H100 microbenchmarks where speedup > overhead.
|
||||
|
||||
## TOML Configuration
|
||||
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output", "auto_filter_small_kn"]
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
## How Float8 Works with Distributed Training
|
||||
|
||||
### Single Device
|
||||
|
||||
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
|
||||
|
||||
```python
|
||||
# Float8 matmul requires scales
|
||||
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
|
||||
```
|
||||
|
||||
### FSDP + Float8
|
||||
|
||||
1. Cast sharded high-precision weights (1/N per rank) to float8
|
||||
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
|
||||
3. Communicate `max(abs)` across ranks for scale computation
|
||||
4. At forward start, have unsharded float8 weights ready
|
||||
|
||||
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
|
||||
|
||||
### TP + Float8
|
||||
|
||||
- **Input**: Cast sharded input to float8, all-gather in float8
|
||||
- **Weights**: Communicate `max(abs)` for sharded weights
|
||||
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
|
||||
|
||||
## Scaling Strategies
|
||||
|
||||
| Strategy | Status | Description |
|
||||
|----------|--------|-------------|
|
||||
| Tensorwise dynamic | Stable | Single scale per tensor |
|
||||
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
|
||||
|
||||
## Performance Gains
|
||||
|
||||
From benchmarks on H100:
|
||||
|
||||
| Configuration | TPS/GPU | vs Baseline |
|
||||
|---------------|---------|-------------|
|
||||
| FSDP only | 5,762 | - |
|
||||
| FSDP + compile | 6,667 | +16% |
|
||||
| FSDP + compile + Float8 | 8,532 | +48% |
|
||||
|
||||
## Determining Float8 Benefit
|
||||
|
||||
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
|
||||
|
||||
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
|
||||
|
||||
## MXFP8 Training (Blackwell)
|
||||
|
||||
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
|
||||
@@ -0,0 +1,126 @@
|
||||
# FSDP2 in TorchTitan
|
||||
|
||||
## Why FSDP2?
|
||||
|
||||
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
|
||||
|
||||
### Key improvements over FSDP1
|
||||
|
||||
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
|
||||
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
|
||||
- **Simplified API**: Fewer arguments, no wrapper class
|
||||
|
||||
### Performance
|
||||
|
||||
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
|
||||
|
||||
## API Reference
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
|
||||
|
||||
@contract(state_cls=FSDPState)
|
||||
def fully_shard(
|
||||
module: nn.Module,
|
||||
*,
|
||||
mesh: Optional[DeviceMesh] = None,
|
||||
reshard_after_forward: Union[bool, int] = True,
|
||||
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
|
||||
offload_policy: OffloadPolicy = OffloadPolicy(),
|
||||
) -> nn.Module:
|
||||
```
|
||||
|
||||
## Sharding Strategies (ZeRO Equivalents)
|
||||
|
||||
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|
||||
|---------------------|------------------|-----------|
|
||||
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
|
||||
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
|
||||
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
|
||||
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
|
||||
|
||||
## Meta-Device Initialization
|
||||
|
||||
FSDP2 supports materializing tensors onto GPU _after_ sharding:
|
||||
|
||||
```python
|
||||
# Initialize on meta device (no memory)
|
||||
with torch.device("meta"):
|
||||
model = Transformer()
|
||||
|
||||
# Apply FSDP2 sharding
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
fully_shard(module)
|
||||
fully_shard(model)
|
||||
|
||||
# Parameters still on meta device
|
||||
for tensor in itertools.chain(model.parameters(), model.buffers()):
|
||||
assert tensor.device == torch.device("meta")
|
||||
|
||||
# Allocate sharded parameters on GPU
|
||||
model.to_empty(device="cuda")
|
||||
|
||||
# Initialize weights
|
||||
model.init_weights()
|
||||
```
|
||||
|
||||
## State Dict Differences
|
||||
|
||||
| Operation | FSDP1 | FSDP2 |
|
||||
|-----------|-------|-------|
|
||||
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
|
||||
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
|
||||
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
|
||||
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
output_dtype=torch.bfloat16,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
fully_shard(model, mp_policy=mp_policy)
|
||||
```
|
||||
|
||||
## HSDP (Hybrid Sharded Data Parallel)
|
||||
|
||||
For 2D parallelism with replication + sharding:
|
||||
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Replicate across 4 groups, shard within 8 GPUs each
|
||||
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
|
||||
|
||||
fully_shard(model, mesh=mesh)
|
||||
```
|
||||
|
||||
## Configuration in TorchTitan
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
# FSDP sharding degree (-1 = auto, use all available GPUs)
|
||||
data_parallel_shard_degree = -1
|
||||
|
||||
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
|
||||
data_parallel_replicate_degree = 1
|
||||
```
|
||||
|
||||
## Removed Arguments from FSDP1
|
||||
|
||||
These FSDP1 arguments are no longer needed:
|
||||
|
||||
- `auto_wrap_policy`: Apply `fully_shard` directly to modules
|
||||
- `backward_prefetch`: Always uses BACKWARD_PRE
|
||||
- `param_init_fn`: Use meta-device initialization
|
||||
- `device_id`: Uses mesh's device automatically
|
||||
- `sync_module_states`: Not needed with DTensor
|
||||
- `limit_all_gathers`: New memory management doesn't need it
|
||||
- `use_orig_params`: Always true (no FlatParameter)
|
||||
@@ -0,0 +1,5 @@
|
||||
# Skills Coming Soon
|
||||
|
||||
This directory will contain high-quality AI research skills for tokenization.
|
||||
|
||||
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
|
||||
@@ -0,0 +1,516 @@
|
||||
---
|
||||
name: huggingface-tokenizers
|
||||
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
|
||||
dependencies: [tokenizers, transformers, datasets]
|
||||
---
|
||||
|
||||
# HuggingFace Tokenizers - Fast Tokenization for NLP
|
||||
|
||||
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
|
||||
|
||||
## When to use HuggingFace Tokenizers
|
||||
|
||||
**Use HuggingFace Tokenizers when:**
|
||||
- Need extremely fast tokenization (<20s per GB of text)
|
||||
- Training custom tokenizers from scratch
|
||||
- Want alignment tracking (token → original text position)
|
||||
- Building production NLP pipelines
|
||||
- Need to tokenize large corpora efficiently
|
||||
|
||||
**Performance**:
|
||||
- **Speed**: <20 seconds to tokenize 1GB on CPU
|
||||
- **Implementation**: Rust core with Python/Node.js bindings
|
||||
- **Efficiency**: 10-100× faster than pure Python implementations
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **SentencePiece**: Language-independent, used by T5/ALBERT
|
||||
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
|
||||
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install tokenizers
|
||||
pip install tokenizers
|
||||
|
||||
# With transformers integration
|
||||
pip install tokenizers transformers
|
||||
```
|
||||
|
||||
### Load pretrained tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Encode text
|
||||
output = tokenizer.encode("Hello, how are you?")
|
||||
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
|
||||
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
|
||||
|
||||
# Decode back
|
||||
text = tokenizer.decode(output.ids)
|
||||
print(text) # "hello, how are you?"
|
||||
```
|
||||
|
||||
### Train custom BPE tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize tokenizer with BPE model
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
# Train on files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
tokenizer.train(files, trainer)
|
||||
|
||||
# Save
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
```
|
||||
|
||||
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
|
||||
|
||||
### Batch encoding with padding
|
||||
|
||||
```python
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
|
||||
|
||||
# Encode batch
|
||||
texts = ["Hello world", "This is a longer sentence"]
|
||||
encodings = tokenizer.encode_batch(texts)
|
||||
|
||||
for encoding in encodings:
|
||||
print(encoding.ids)
|
||||
# [101, 7592, 2088, 102, 3, 3, 3]
|
||||
# [101, 2023, 2003, 1037, 2936, 6251, 102]
|
||||
```
|
||||
|
||||
## Tokenization algorithms
|
||||
|
||||
### BPE (Byte-Pair Encoding)
|
||||
|
||||
**How it works**:
|
||||
1. Start with character-level vocabulary
|
||||
2. Find most frequent character pair
|
||||
3. Merge into new token, add to vocabulary
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50257,
|
||||
special_tokens=["<|endoftext|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles OOV words well (breaks into subwords)
|
||||
- Flexible vocabulary size
|
||||
- Good for morphologically rich languages
|
||||
|
||||
**Trade-offs**:
|
||||
- Tokenization depends on merge order
|
||||
- May split common words unexpectedly
|
||||
|
||||
### WordPiece
|
||||
|
||||
**How it works**:
|
||||
1. Start with character vocabulary
|
||||
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
|
||||
3. Merge highest scoring pair
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: BERT, DistilBERT, MobileBERT
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Prioritizes meaningful merges (high score = semantically related)
|
||||
- Used successfully in BERT (state-of-the-art results)
|
||||
|
||||
**Trade-offs**:
|
||||
- Unknown words become `[UNK]` if no subword match
|
||||
- Saves vocabulary, not merge rules (larger files)
|
||||
|
||||
### Unigram
|
||||
|
||||
**How it works**:
|
||||
1. Start with large vocabulary (all substrings)
|
||||
2. Compute loss for corpus with current vocabulary
|
||||
3. Remove tokens with minimal impact on loss
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Probabilistic (finds most likely tokenization)
|
||||
- Works well for languages without word boundaries
|
||||
- Handles diverse linguistic contexts
|
||||
|
||||
**Trade-offs**:
|
||||
- Computationally expensive to train
|
||||
- More hyperparameters to tune
|
||||
|
||||
## Tokenization pipeline
|
||||
|
||||
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
|
||||
|
||||
### Normalization
|
||||
|
||||
Clean and standardize text:
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
|
||||
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode normalization (decompose)
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Héllo WORLD"
|
||||
# After normalization: "hello world"
|
||||
```
|
||||
|
||||
**Common normalizers**:
|
||||
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
|
||||
- `Lowercase()` - Convert to lowercase
|
||||
- `StripAccents()` - Remove accents (é → e)
|
||||
- `Strip()` - Remove whitespace
|
||||
- `Replace(pattern, content)` - Regex replacement
|
||||
|
||||
### Pre-tokenization
|
||||
|
||||
Split text into word-like units:
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
|
||||
|
||||
# Split on whitespace and punctuation
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation()
|
||||
])
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# After pre-tokenization: ["Hello", ",", "world", "!"]
|
||||
```
|
||||
|
||||
**Common pre-tokenizers**:
|
||||
- `Whitespace()` - Split on spaces, tabs, newlines
|
||||
- `ByteLevel()` - GPT-2 style byte-level splitting
|
||||
- `Punctuation()` - Isolate punctuation
|
||||
- `Digits(individual_digits=True)` - Split digits individually
|
||||
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
|
||||
|
||||
### Post-processing
|
||||
|
||||
Add special tokens for model input:
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style: [CLS] sentence [SEP]
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 1),
|
||||
("[SEP]", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**Common patterns**:
|
||||
```python
|
||||
# GPT-2: sentence <|endoftext|>
|
||||
TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)]
|
||||
)
|
||||
|
||||
# RoBERTa: <s> sentence </s>
|
||||
TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[("<s>", 0), ("</s>", 2)]
|
||||
)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text:
|
||||
|
||||
```python
|
||||
output = tokenizer.encode("Hello, world!")
|
||||
|
||||
# Get token offsets
|
||||
for token, offset in zip(output.tokens, output.offsets):
|
||||
start, end = offset
|
||||
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Named entity recognition (map predictions back to text)
|
||||
- Question answering (extract answer spans)
|
||||
- Token classification (align labels to original positions)
|
||||
|
||||
## Integration with transformers
|
||||
|
||||
### Load with AutoTokenizer
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# AutoTokenizer automatically uses fast tokenizers
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Convert custom tokenizer to transformers
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
# ... train tokenizer ...
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Use like any transformers tokenizer
|
||||
outputs = transformers_tokenizer(
|
||||
"Hello world",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Train from iterator (large datasets)
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
# Create batch iterator
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
|
||||
# Train tokenizer
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # For progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Performance**: Processes 1GB in ~10-20 minutes
|
||||
|
||||
### Enable truncation and padding
|
||||
|
||||
```python
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(
|
||||
pad_id=tokenizer.token_to_id("[PAD]"),
|
||||
pad_token="[PAD]",
|
||||
length=512 # Fixed length, or None for batch max
|
||||
)
|
||||
|
||||
# Encode with both
|
||||
output = tokenizer.encode("This is a long sentence that will be truncated...")
|
||||
print(len(output.ids)) # 512
|
||||
```
|
||||
|
||||
### Multi-processing
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from multiprocessing import Pool
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = Tokenizer.from_file("tokenizer.json")
|
||||
|
||||
def encode_batch(texts):
|
||||
return tokenizer.encode_batch(texts)
|
||||
|
||||
# Process large corpus in parallel
|
||||
with Pool(8) as pool:
|
||||
# Split corpus into chunks
|
||||
chunk_size = 1000
|
||||
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
|
||||
|
||||
# Encode in parallel
|
||||
results = pool.map(encode_batch, chunks)
|
||||
```
|
||||
|
||||
**Speedup**: 5-8× with 8 cores
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Training speed
|
||||
|
||||
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|
||||
|-------------|-----------------|-----------------|--------------|
|
||||
| 10 MB | 15 sec | 18 sec | 25 sec |
|
||||
| 100 MB | 1.5 min | 2 min | 4 min |
|
||||
| 1 GB | 15 min | 20 min | 40 min |
|
||||
|
||||
**Hardware**: 16-core CPU, tested on English Wikipedia
|
||||
|
||||
### Tokenization speed
|
||||
|
||||
| Implementation | 1 GB corpus | Throughput |
|
||||
|----------------|-------------|---------------|
|
||||
| Pure Python | ~20 minutes | ~50 MB/min |
|
||||
| HF Tokenizers | ~15 seconds | ~4 GB/min |
|
||||
| **Speedup** | **80×** | **80×** |
|
||||
|
||||
**Test**: English text, average sentence length 20 words
|
||||
|
||||
### Memory usage
|
||||
|
||||
| Task | Memory |
|
||||
|-------------------------|---------|
|
||||
| Load tokenizer | ~10 MB |
|
||||
| Train BPE (30k vocab) | ~200 MB |
|
||||
| Encode 1M sentences | ~500 MB |
|
||||
|
||||
## Supported models
|
||||
|
||||
Pre-trained tokenizers available via `from_pretrained()`:
|
||||
|
||||
**BERT family**:
|
||||
- `bert-base-uncased`, `bert-large-cased`
|
||||
- `distilbert-base-uncased`
|
||||
- `roberta-base`, `roberta-large`
|
||||
|
||||
**GPT family**:
|
||||
- `gpt2`, `gpt2-medium`, `gpt2-large`
|
||||
- `distilgpt2`
|
||||
|
||||
**T5 family**:
|
||||
- `t5-small`, `t5-base`, `t5-large`
|
||||
- `google/flan-t5-xxl`
|
||||
|
||||
**Other**:
|
||||
- `facebook/bart-base`, `facebook/mbart-large-cc25`
|
||||
- `albert-base-v2`, `albert-xlarge-v2`
|
||||
- `xlm-roberta-base`, `xlm-roberta-large`
|
||||
|
||||
Browse all: https://huggingface.co/models?library=tokenizers
|
||||
|
||||
## References
|
||||
|
||||
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
|
||||
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
|
||||
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
|
||||
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://huggingface.co/docs/tokenizers
|
||||
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
|
||||
- **Version**: 0.20.0+
|
||||
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
|
||||
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)
|
||||
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Tokenization Algorithms Deep Dive
|
||||
|
||||
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
|
||||
|
||||
## Byte-Pair Encoding (BPE)
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
BPE iteratively merges the most frequent pair of tokens in a corpus.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all adjacent token pairs
|
||||
3. Merge most frequent pair into new token
|
||||
4. Add new token to vocabulary
|
||||
5. Update corpus with new token
|
||||
6. Repeat until vocabulary size reached
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
|
||||
Count pairs:
|
||||
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
|
||||
'l' + 'o': 7
|
||||
'o' + 'w': 7
|
||||
...
|
||||
|
||||
Merge: 'e' + 's' → 'es'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → newes|t: 6
|
||||
widest: 3 → wides|t: 3
|
||||
|
||||
Vocabulary: [a-z] + ['es']
|
||||
```
|
||||
|
||||
**Iteration 2**:
|
||||
```
|
||||
Count pairs:
|
||||
'es' + 't': 9 ← most frequent
|
||||
'l' + 'o': 7
|
||||
...
|
||||
|
||||
Merge: 'es' + 't' → 'est'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → new|est: 6
|
||||
widest: 3 → wid|est: 3
|
||||
|
||||
Vocabulary: [a-z] + ['es', 'est']
|
||||
```
|
||||
|
||||
**Continue until desired vocabulary size...**
|
||||
|
||||
### Tokenization with trained BPE
|
||||
|
||||
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Split into characters
|
||||
['l', 'o', 'w', 'e', 's', 't']
|
||||
|
||||
Step 2: Apply merges in order learned during training
|
||||
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
|
||||
- Merge 'lo' + 'w' → 'low' (if learned)
|
||||
- Merge 'e' + 's' → 'es' (learned)
|
||||
- Merge 'es' + 't' → 'est' (learned)
|
||||
|
||||
Final: ['low', 'est']
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=1000,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
|
||||
# Train
|
||||
corpus = [
|
||||
"This is a sample corpus for BPE training.",
|
||||
"BPE learns subword units from the training data.",
|
||||
# ... more sentences
|
||||
]
|
||||
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("This is tokenization")
|
||||
print(output.tokens) # ['This', 'is', 'token', 'ization']
|
||||
```
|
||||
|
||||
### Byte-level BPE (GPT-2 variant)
|
||||
|
||||
**Problem**: Standard BPE has limited character coverage (256+ Unicode chars)
|
||||
|
||||
**Solution**: Operate on byte level (256 bytes)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
# This handles ALL possible characters, including emojis
|
||||
text = "Hello 🌍 世界"
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles any Unicode character (256 byte coverage)
|
||||
- No unknown tokens (worst case: bytes)
|
||||
- Used by GPT-2, GPT-3, BART
|
||||
|
||||
**Trade-offs**:
|
||||
- Slightly worse compression (bytes vs characters)
|
||||
- More tokens for non-ASCII text
|
||||
|
||||
### BPE variants
|
||||
|
||||
**SentencePiece BPE**:
|
||||
- Language-independent (no pre-tokenization)
|
||||
- Treats input as raw byte stream
|
||||
- Used by T5, ALBERT, XLNet
|
||||
|
||||
**Robust BPE**:
|
||||
- Dropout during training (randomly skip merges)
|
||||
- More robust tokenization at inference
|
||||
- Reduces overfitting to training data
|
||||
|
||||
## WordPiece
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
WordPiece is similar to BPE but uses a different merge selection criterion.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all token pairs
|
||||
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
|
||||
4. Merge pair with highest score
|
||||
5. Repeat until vocabulary size reached
|
||||
|
||||
### Why different scoring?
|
||||
|
||||
**BPE**: Merges most frequent pairs
|
||||
- "aa" appears 100 times → high priority
|
||||
- Even if 'a' appears 1000 times alone
|
||||
|
||||
**WordPiece**: Merges pairs that are semantically related
|
||||
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
|
||||
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
|
||||
- Prioritizes pairs that appear together more than expected
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
|
||||
Count frequencies:
|
||||
'e': 11 (lower: 2, newest: 6, widest: 3)
|
||||
's': 9
|
||||
't': 9
|
||||
...
|
||||
|
||||
Count pairs:
|
||||
'e' + 's': 9 (newest: 6, widest: 3)
|
||||
'es' + 't': 9 (newest: 6, widest: 3)
|
||||
...
|
||||
|
||||
Compute scores:
|
||||
score('e' + 's') = 9 / (11 × 9) = 0.091
|
||||
score('es' + 't') = 9 / (9 × 9) = 0.111 ← highest score
|
||||
score('l' + 'o') = 7 / (7 × 9) = 0.111 ← tied
|
||||
|
||||
Choose: 'es' + 't' → 'est' (or 'lo' if tied)
|
||||
```
|
||||
|
||||
**Key difference**: WordPiece prioritizes rare combinations over frequent ones.
|
||||
|
||||
### Tokenization with WordPiece
|
||||
|
||||
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', 'est', 'low']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Find longest matching prefix
|
||||
'lowest' → 'low' (matches)
|
||||
|
||||
Step 2: Find longest match for remainder
|
||||
'est' → 'est' (matches)
|
||||
|
||||
Final: ['low', 'est']
|
||||
```
|
||||
|
||||
**If no match**:
|
||||
```
|
||||
Tokenize "unknownword":
|
||||
'unknownword' → no match
|
||||
'unknown' → no match
|
||||
'unkn' → no match
|
||||
'un' → no match
|
||||
'u' → no match
|
||||
→ [UNK]
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
# Initialize BERT-style tokenizer
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization (lowercase, accent stripping)
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization (whitespace + punctuation)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Configure trainer
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT vocab size
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##" # BERT uses ##
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization works great!")
|
||||
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
|
||||
```
|
||||
|
||||
### Subword prefix
|
||||
|
||||
**BERT uses `##` prefix**:
|
||||
```
|
||||
"unbelievable" → ['un', '##believ', '##able']
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Indicates token is a continuation
|
||||
- Allows reconstruction: remove ##, concatenate
|
||||
- Helps model distinguish word boundaries
|
||||
|
||||
### WordPiece advantages
|
||||
|
||||
**Semantic merges**:
|
||||
- Prioritizes meaningful combinations
|
||||
- "qu" has high score (always together)
|
||||
- "qx" has low score (rare combination)
|
||||
|
||||
**Better for morphology**:
|
||||
- Captures affixes: un-, -ing, -ed
|
||||
- Preserves word stems
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training than BPE
|
||||
- More memory (stores vocabulary, not merges)
|
||||
- Original implementation not open-source (HF reimplementation)
|
||||
|
||||
## Unigram
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
Unigram works backward: start with large vocabulary, remove tokens.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize with large vocabulary (all substrings)
|
||||
2. Estimate probability of each token (frequency-based)
|
||||
3. For each token, compute loss increase if removed
|
||||
4. Remove 10-20% of tokens with lowest loss impact
|
||||
5. Re-estimate probabilities
|
||||
6. Repeat until desired vocabulary size
|
||||
|
||||
### Probabilistic tokenization
|
||||
|
||||
**Unigram assumption**: Each token is independent.
|
||||
|
||||
Given vocabulary with probabilities:
|
||||
```
|
||||
P('low') = 0.02
|
||||
P('l') = 0.01
|
||||
P('o') = 0.015
|
||||
P('w') = 0.01
|
||||
P('est') = 0.03
|
||||
P('e') = 0.02
|
||||
P('s') = 0.015
|
||||
P('t') = 0.015
|
||||
```
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Option 1: ['low', 'est']
|
||||
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
|
||||
|
||||
Option 2: ['l', 'o', 'w', 'est']
|
||||
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
|
||||
|
||||
Option 3: ['low', 'e', 's', 't']
|
||||
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.0000009
|
||||
|
||||
Choose option 1 (highest probability)
|
||||
```
|
||||
|
||||
### Viterbi algorithm
|
||||
|
||||
Finding best tokenization is expensive (exponential possibilities).
|
||||
|
||||
**Viterbi algorithm** (dynamic programming):
|
||||
```python
|
||||
def tokenize_viterbi(word, vocab, probs):
|
||||
n = len(word)
|
||||
# dp[i] = (best_prob, best_tokens) for word[:i]
|
||||
dp = [{} for _ in range(n + 1)]
|
||||
dp[0] = (0.0, []) # log probability
|
||||
|
||||
for i in range(1, n + 1):
|
||||
best_prob = float('-inf')
|
||||
best_tokens = []
|
||||
|
||||
# Try all possible last tokens
|
||||
for j in range(i):
|
||||
token = word[j:i]
|
||||
if token in vocab:
|
||||
prob = dp[j][0] + log(probs[token])
|
||||
if prob > best_prob:
|
||||
best_prob = prob
|
||||
best_tokens = dp[j][1] + [token]
|
||||
|
||||
dp[i] = (best_prob, best_tokens)
|
||||
|
||||
return dp[n][1]
|
||||
```
|
||||
|
||||
**Time complexity**: O(n² × vocab_size) vs O(2^n) brute force
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Configure trainer
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Max token length
|
||||
n_sub_iterations=2, # EM iterations
|
||||
shrinking_factor=0.75 # Remove 25% each iteration
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization with Unigram")
|
||||
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
|
||||
```
|
||||
|
||||
### Unigram advantages
|
||||
|
||||
**Probabilistic**:
|
||||
- Multiple valid tokenizations
|
||||
- Can sample different tokenizations (data augmentation)
|
||||
|
||||
**Subword regularization**:
|
||||
```python
|
||||
# Sample different tokenizations
|
||||
for _ in range(3):
|
||||
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
|
||||
print(tokens)
|
||||
|
||||
# Output (different each time):
|
||||
# ['token', 'ization']
|
||||
# ['tok', 'en', 'ization']
|
||||
# ['token', 'iz', 'ation']
|
||||
```
|
||||
|
||||
**Language-independent**:
|
||||
- No word boundaries needed
|
||||
- Works for CJK languages (Chinese, Japanese, Korean)
|
||||
- Treats input as character stream
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training (EM algorithm)
|
||||
- More hyperparameters
|
||||
- Larger model (stores probabilities)
|
||||
|
||||
## Algorithm comparison
|
||||
|
||||
### Training speed
|
||||
|
||||
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|
||||
|------------|--------------|----------------|-------------|
|
||||
| BPE | 10-15 sec | 1-2 min | 10-20 min |
|
||||
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
|
||||
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
|
||||
|
||||
**Tested on**: 16-core CPU, 30k vocab
|
||||
|
||||
### Tokenization quality
|
||||
|
||||
Tested on English Wikipedia (perplexity measurement):
|
||||
|
||||
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|
||||
|------------|------------|-------------|--------------|
|
||||
| BPE | 30k | 1.3 | 0.5% |
|
||||
| WordPiece | 30k | 1.2 | 1.2% |
|
||||
| Unigram | 8k | 1.5 | 0.3% |
|
||||
|
||||
**Key observations**:
|
||||
- WordPiece: Slightly better compression
|
||||
- BPE: Lower unknown rate
|
||||
- Unigram: Smallest vocab, good coverage
|
||||
|
||||
### Compression ratio
|
||||
|
||||
Characters per token (higher = better compression):
|
||||
|
||||
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|
||||
|----------|-----------|-----------------|--------------|
|
||||
| English | 4.2 | 4.5 | 3.8 |
|
||||
| Chinese | 2.1 | 2.3 | 2.5 |
|
||||
| Arabic | 3.5 | 3.8 | 3.2 |
|
||||
|
||||
**Best for each**:
|
||||
- English: WordPiece
|
||||
- Chinese: Unigram (language-independent)
|
||||
- Arabic: WordPiece
|
||||
|
||||
### Use case recommendations
|
||||
|
||||
**BPE** - Best for:
|
||||
- English language models
|
||||
- Code (handles symbols well)
|
||||
- Fast training needed
|
||||
- **Models**: GPT-2, GPT-3, RoBERTa, BART
|
||||
|
||||
**WordPiece** - Best for:
|
||||
- Masked language modeling (BERT-style)
|
||||
- Morphologically rich languages
|
||||
- Semantic understanding tasks
|
||||
- **Models**: BERT, DistilBERT, ELECTRA
|
||||
|
||||
**Unigram** - Best for:
|
||||
- Multilingual models
|
||||
- Languages without word boundaries (CJK)
|
||||
- Data augmentation via subword regularization
|
||||
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
|
||||
|
||||
## Advanced topics
|
||||
|
||||
### Handling rare words
|
||||
|
||||
**BPE approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
**WordPiece approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
|
||||
```
|
||||
|
||||
**Unigram approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
### Handling numbers
|
||||
|
||||
**Challenge**: Infinite number combinations
|
||||
|
||||
**BPE solution**: Byte-level (handles any digit sequence)
|
||||
```python
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Handles any number
|
||||
"123456789" → byte-level tokens
|
||||
```
|
||||
|
||||
**WordPiece solution**: Digit pre-tokenization
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually or as groups
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
"123" → ['1', '2', '3']
|
||||
```
|
||||
|
||||
**Unigram solution**: Learns common number patterns
|
||||
```python
|
||||
# Learns patterns during training
|
||||
"2023" → ['202', '3'] or ['20', '23']
|
||||
```
|
||||
|
||||
### Handling case sensitivity
|
||||
|
||||
**Lowercase (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
"Hello WORLD" → "hello world" → ['hello', 'world']
|
||||
```
|
||||
|
||||
**Preserve case (GPT-2)**:
|
||||
```python
|
||||
# No case normalization
|
||||
tokenizer.normalizer = None
|
||||
|
||||
"Hello WORLD" → ['Hello', 'WORLD']
|
||||
```
|
||||
|
||||
**Cased tokens (RoBERTa)**:
|
||||
```python
|
||||
# Learns separate tokens for different cases
|
||||
Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
|
||||
```
|
||||
|
||||
### Handling emojis and special characters
|
||||
|
||||
**Byte-level (GPT-2)**:
|
||||
```python
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
"Hello 🌍 👋" → byte-level representation (always works)
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFKC
|
||||
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
"é" (composed) ↔ "é" (decomposed) → normalized to one form
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Poor subword splitting
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Train longer (more merge iterations)
|
||||
3. Lower `min_frequency` threshold
|
||||
|
||||
### Issue: Too many unknown tokens
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
5% of tokens are [UNK]
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Use byte-level BPE (no UNK possible)
|
||||
3. Verify training corpus is representative
|
||||
|
||||
### Issue: Inconsistent tokenization
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['run', 'ning']
|
||||
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Check normalization consistency
|
||||
2. Ensure pre-tokenization is deterministic
|
||||
3. Use Unigram for probabilistic variance
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match algorithm to model architecture**:
|
||||
- BERT-style → WordPiece
|
||||
- GPT-style → BPE
|
||||
- T5-style → Unigram
|
||||
|
||||
2. **Use byte-level for multilingual**:
|
||||
- Handles any Unicode
|
||||
- No unknown tokens
|
||||
|
||||
3. **Test on representative data**:
|
||||
- Measure compression ratio
|
||||
- Check unknown token rate
|
||||
- Inspect sample tokenizations
|
||||
|
||||
4. **Version control tokenizers**:
|
||||
- Save with model
|
||||
- Document special tokens
|
||||
- Track vocabulary changes
|
||||
@@ -0,0 +1,637 @@
|
||||
# Transformers Integration
|
||||
|
||||
Complete guide to using HuggingFace Tokenizers with the Transformers library.
|
||||
|
||||
## AutoTokenizer
|
||||
|
||||
The easiest way to load tokenizers.
|
||||
|
||||
### Loading pretrained tokenizers
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer (Rust-based)
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
if tokenizer.is_fast:
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Fast vs slow tokenizers
|
||||
|
||||
| Feature | Fast (Rust) | Slow (Python) |
|
||||
|--------------------------|----------------|---------------|
|
||||
| Speed | 5-10× faster | Baseline |
|
||||
| Alignment tracking | ✅ Full support | ❌ Limited |
|
||||
| Batch processing | ✅ Optimized | ⚠️ Slower |
|
||||
| Offset mapping | ✅ Yes | ❌ No |
|
||||
| Installation | `tokenizers` | Built-in |
|
||||
|
||||
**Always use fast tokenizers when available.**
|
||||
|
||||
### Check available tokenizers
|
||||
|
||||
```python
|
||||
from transformers import TOKENIZER_MAPPING
|
||||
|
||||
# List all fast tokenizers
|
||||
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
|
||||
if fast is not None:
|
||||
print(f"{config_class.__name__}: {fast.__name__}")
|
||||
```
|
||||
|
||||
## PreTrainedTokenizerFast
|
||||
|
||||
Wrap custom tokenizers for transformers.
|
||||
|
||||
### Convert custom tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Save in transformers format
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer")
|
||||
```
|
||||
|
||||
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
|
||||
|
||||
### Use like any transformers tokenizer
|
||||
|
||||
```python
|
||||
# Load
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
|
||||
|
||||
# Encode with all transformers features
|
||||
outputs = tokenizer(
|
||||
"Hello world",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=128,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
print(outputs.keys())
|
||||
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
|
||||
```
|
||||
|
||||
## Special tokens
|
||||
|
||||
### Default special tokens
|
||||
|
||||
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|
||||
|--------------|---------|---------------|---------|---------|---------|
|
||||
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
|
||||
| GPT-2 | - | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | - |
|
||||
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
|
||||
| T5 | - | </s> | <pad> | <unk> | - |
|
||||
|
||||
### Adding special tokens
|
||||
|
||||
```python
|
||||
# Add new special tokens
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
|
||||
}
|
||||
|
||||
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Added {num_added_tokens} tokens")
|
||||
|
||||
# Resize model embeddings
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Use new tokens
|
||||
text = "This is an image: <|image|>"
|
||||
tokens = tokenizer.encode(text)
|
||||
```
|
||||
|
||||
### Adding regular tokens
|
||||
|
||||
```python
|
||||
# Add domain-specific tokens
|
||||
new_tokens = ["COVID-19", "mRNA", "vaccine"]
|
||||
num_added = tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# These are NOT special tokens (can be split if needed)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=False)
|
||||
|
||||
# These ARE special tokens (never split)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
```
|
||||
|
||||
## Encoding and decoding
|
||||
|
||||
### Basic encoding
|
||||
|
||||
```python
|
||||
# Single sentence
|
||||
text = "Hello, how are you?"
|
||||
encoded = tokenizer(text)
|
||||
|
||||
print(encoded)
|
||||
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
|
||||
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
|
||||
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
|
||||
```
|
||||
|
||||
### Batch encoding
|
||||
|
||||
```python
|
||||
# Multiple sentences
|
||||
texts = ["Hello world", "How are you?", "I am fine"]
|
||||
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
|
||||
|
||||
print(encoded['input_ids'])
|
||||
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
|
||||
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
|
||||
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
|
||||
```
|
||||
|
||||
### Return tensors
|
||||
|
||||
```python
|
||||
# Return PyTorch tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="pt")
|
||||
print(outputs['input_ids'].shape) # torch.Size([1, 5])
|
||||
|
||||
# Return TensorFlow tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="tf")
|
||||
|
||||
# Return NumPy arrays
|
||||
outputs = tokenizer("Hello world", return_tensors="np")
|
||||
|
||||
# Return lists (default)
|
||||
outputs = tokenizer("Hello world", return_tensors=None)
|
||||
```
|
||||
|
||||
### Decoding
|
||||
|
||||
```python
|
||||
# Decode token IDs
|
||||
ids = [101, 7592, 2088, 102]
|
||||
text = tokenizer.decode(ids)
|
||||
print(text) # "[CLS] hello world [SEP]"
|
||||
|
||||
# Skip special tokens
|
||||
text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
print(text) # "hello world"
|
||||
|
||||
# Batch decode
|
||||
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
|
||||
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
|
||||
print(texts) # ["hello", "world"]
|
||||
```
|
||||
|
||||
## Padding and truncation
|
||||
|
||||
### Padding strategies
|
||||
|
||||
```python
|
||||
# Pad to max length in batch
|
||||
tokenizer(texts, padding="longest")
|
||||
|
||||
# Pad to model max length
|
||||
tokenizer(texts, padding="max_length", max_length=128)
|
||||
|
||||
# No padding
|
||||
tokenizer(texts, padding=False)
|
||||
|
||||
# Pad to multiple of value (for efficient computation)
|
||||
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
|
||||
# Result: length will be 128 (already multiple of 8)
|
||||
```
|
||||
|
||||
### Truncation strategies
|
||||
|
||||
```python
|
||||
# Truncate to max length
|
||||
tokenizer(text, truncation=True, max_length=10)
|
||||
|
||||
# Only truncate first sequence (for pairs)
|
||||
tokenizer(text1, text2, truncation="only_first", max_length=20)
|
||||
|
||||
# Only truncate second sequence
|
||||
tokenizer(text1, text2, truncation="only_second", max_length=20)
|
||||
|
||||
# Truncate longest first (default for pairs)
|
||||
tokenizer(text1, text2, truncation="longest_first", max_length=20)
|
||||
|
||||
# No truncation (error if too long)
|
||||
tokenizer(text, truncation=False)
|
||||
```
|
||||
|
||||
### Stride for long documents
|
||||
|
||||
```python
|
||||
# For documents longer than max_length
|
||||
text = "Very long document " * 1000
|
||||
|
||||
# Encode with overlap
|
||||
encodings = tokenizer(
|
||||
text,
|
||||
max_length=512,
|
||||
stride=128, # Overlap between chunks
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True
|
||||
)
|
||||
|
||||
# Get all chunks
|
||||
num_chunks = len(encodings['input_ids'])
|
||||
print(f"Split into {num_chunks} chunks")
|
||||
|
||||
# Each chunk overlaps by stride tokens
|
||||
for i, chunk in enumerate(encodings['input_ids']):
|
||||
print(f"Chunk {i}: {len(chunk)} tokens")
|
||||
```
|
||||
|
||||
**Use case**: Long document QA, sliding window inference
|
||||
|
||||
## Alignment and offsets
|
||||
|
||||
### Offset mapping
|
||||
|
||||
```python
|
||||
# Get character offsets for each token
|
||||
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
|
||||
|
||||
for token, (start, end) in zip(
|
||||
encoded.tokens(),
|
||||
encoded['offset_mapping'][0]
|
||||
):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d})")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0)
|
||||
# Hello → [ 0, 5)
|
||||
# , → [ 5, 6)
|
||||
# world → [ 7, 12)
|
||||
# ! → [12, 13)
|
||||
# [SEP] → [ 0, 0)
|
||||
```
|
||||
|
||||
### Word IDs
|
||||
|
||||
```python
|
||||
# Get word index for each token
|
||||
encoded = tokenizer("Hello world", return_offsets_mapping=True)
|
||||
word_ids = encoded.word_ids()
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER, POS tagging)
|
||||
|
||||
### Character to token mapping
|
||||
|
||||
```python
|
||||
text = "Machine learning is awesome"
|
||||
encoded = tokenizer(text, return_offsets_mapping=True)
|
||||
|
||||
# Find token for character position
|
||||
char_pos = 8 # "l" in "learning"
|
||||
token_idx = encoded.char_to_token(char_pos)
|
||||
|
||||
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
|
||||
# Character 8 is in token 2: learning
|
||||
```
|
||||
|
||||
**Use case**: Question answering (map answer character span to tokens)
|
||||
|
||||
### Sequence pairs
|
||||
|
||||
```python
|
||||
# Encode sentence pair
|
||||
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
|
||||
|
||||
# Get sequence IDs (which sequence each token belongs to)
|
||||
sequence_ids = encoded.sequence_ids()
|
||||
print(sequence_ids)
|
||||
# [None, 0, 0, 0, None, 1, 1, 1, None]
|
||||
# None = special token, 0 = question, 1 = answer
|
||||
```
|
||||
|
||||
## Model integration
|
||||
|
||||
### Use with transformers models
|
||||
|
||||
```python
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained("bert-base-uncased")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Tokenize
|
||||
text = "Hello world"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Get embeddings
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
|
||||
```
|
||||
|
||||
### Custom model with custom tokenizer
|
||||
|
||||
```python
|
||||
from transformers import BertConfig, BertModel
|
||||
|
||||
# Train custom tokenizer
|
||||
from tokenizers import Tokenizer, models, trainers
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
trainer = trainers.BpeTrainer(vocab_size=30000)
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
|
||||
# Wrap for transformers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
fast_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]"
|
||||
)
|
||||
|
||||
# Create model with custom vocab size
|
||||
config = BertConfig(vocab_size=30000)
|
||||
model = BertModel(config)
|
||||
|
||||
# Use together
|
||||
inputs = fast_tokenizer("Hello world", return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
```
|
||||
|
||||
### Save and load together
|
||||
|
||||
```python
|
||||
# Save both
|
||||
model.save_pretrained("my-model")
|
||||
tokenizer.save_pretrained("my-model")
|
||||
|
||||
# Directory structure:
|
||||
# my-model/
|
||||
# ├── config.json
|
||||
# ├── pytorch_model.bin
|
||||
# ├── tokenizer.json
|
||||
# ├── tokenizer_config.json
|
||||
# └── special_tokens_map.json
|
||||
|
||||
# Load both
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model = AutoModel.from_pretrained("my-model")
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-model")
|
||||
```
|
||||
|
||||
## Advanced features
|
||||
|
||||
### Multimodal tokenization
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# LLaVA-style (image + text)
|
||||
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
# Add image placeholder token
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
|
||||
# Use in prompt
|
||||
text = "Describe this image: <image>"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Template formatting
|
||||
|
||||
```python
|
||||
# Chat template
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi! How can I help?"},
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
]
|
||||
|
||||
# Apply chat template (if tokenizer has one)
|
||||
if hasattr(tokenizer, "apply_chat_template"):
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Custom template
|
||||
|
||||
```python
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
|
||||
|
||||
# Define chat template
|
||||
tokenizer.chat_template = """
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
System: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'user' %}
|
||||
User: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
Assistant: {{ message['content'] }}\\n
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
Assistant:
|
||||
"""
|
||||
|
||||
# Use template
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Batch processing
|
||||
|
||||
```python
|
||||
# Process large datasets efficiently
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("imdb", split="train[:1000]")
|
||||
|
||||
# Tokenize in batches
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Map over dataset (batched)
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
batch_size=1000,
|
||||
num_proc=4 # Parallel processing
|
||||
)
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
# Enable caching for repeated tokenization
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
use_fast=True,
|
||||
cache_dir="./cache" # Cache tokenizer files
|
||||
)
|
||||
|
||||
# Tokenize with caching
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=10000)
|
||||
def cached_tokenize(text):
|
||||
return tuple(tokenizer.encode(text))
|
||||
|
||||
# Reuses cached results for repeated inputs
|
||||
```
|
||||
|
||||
### Memory efficiency
|
||||
|
||||
```python
|
||||
# For very large datasets, use streaming
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("pile", split="train", streaming=True)
|
||||
|
||||
def process_batch(batch):
|
||||
# Tokenize
|
||||
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
|
||||
|
||||
# Process tokens...
|
||||
|
||||
return tokens
|
||||
|
||||
# Process in chunks (memory efficient)
|
||||
for batch in dataset.batch(batch_size=1000):
|
||||
processed = process_batch(batch)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Tokenizer not fast
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer.is_fast # False
|
||||
```
|
||||
|
||||
**Solution**: Install tokenizers library
|
||||
```bash
|
||||
pip install tokenizers
|
||||
```
|
||||
|
||||
### Issue: Special tokens not working
|
||||
|
||||
**Symptom**: Special tokens are split into subwords
|
||||
|
||||
**Solution**: Add as special tokens, not regular tokens
|
||||
```python
|
||||
# Wrong
|
||||
tokenizer.add_tokens(["<|image|>"])
|
||||
|
||||
# Correct
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
|
||||
```
|
||||
|
||||
### Issue: Offset mapping not available
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer("text", return_offsets_mapping=True)
|
||||
# Error: return_offsets_mapping not supported
|
||||
```
|
||||
|
||||
**Solution**: Use fast tokenizer
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load fast version
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
|
||||
```
|
||||
|
||||
### Issue: Padding inconsistent
|
||||
|
||||
**Symptom**: Some sequences padded, others not
|
||||
|
||||
**Solution**: Specify padding strategy
|
||||
```python
|
||||
# Explicit padding
|
||||
tokenizer(
|
||||
texts,
|
||||
padding="max_length", # or "longest"
|
||||
max_length=128
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use fast tokenizers**:
|
||||
- 5-10× faster
|
||||
- Full alignment tracking
|
||||
- Better batch processing
|
||||
|
||||
2. **Save tokenizer with model**:
|
||||
- Ensures reproducibility
|
||||
- Prevents version mismatches
|
||||
|
||||
3. **Use batch processing for datasets**:
|
||||
- Tokenize with `.map(batched=True)`
|
||||
- Set `num_proc` for parallelism
|
||||
|
||||
4. **Enable caching for repeated inputs**:
|
||||
- Use `lru_cache` for inference
|
||||
- Cache tokenizer files with `cache_dir`
|
||||
|
||||
5. **Handle special tokens properly**:
|
||||
- Use `add_special_tokens()` for never-split tokens
|
||||
- Resize embeddings after adding tokens
|
||||
|
||||
6. **Test alignment for downstream tasks**:
|
||||
- Verify `offset_mapping` is correct
|
||||
- Test `char_to_token()` on samples
|
||||
|
||||
7. **Version control tokenizer config**:
|
||||
- Save `tokenizer_config.json`
|
||||
- Document custom templates
|
||||
- Track vocabulary changes
|
||||
@@ -0,0 +1,723 @@
|
||||
# Tokenization Pipeline Components
|
||||
|
||||
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
|
||||
|
||||
## Pipeline overview
|
||||
|
||||
**Full tokenization pipeline**:
|
||||
```
|
||||
Raw Text
|
||||
↓
|
||||
Normalization (cleaning, lowercasing)
|
||||
↓
|
||||
Pre-tokenization (split into words)
|
||||
↓
|
||||
Model (apply BPE/WordPiece/Unigram)
|
||||
↓
|
||||
Post-processing (add special tokens)
|
||||
↓
|
||||
Token IDs
|
||||
```
|
||||
|
||||
**Decoding reverses the process**:
|
||||
```
|
||||
Token IDs
|
||||
↓
|
||||
Decoder (handle special encodings)
|
||||
↓
|
||||
Raw Text
|
||||
```
|
||||
|
||||
## Normalizers
|
||||
|
||||
Clean and standardize input text.
|
||||
|
||||
### Common normalizers
|
||||
|
||||
**Lowercase**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
# Input: "Hello WORLD"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
|
||||
|
||||
# NFD: Canonical decomposition
|
||||
tokenizer.normalizer = NFD()
|
||||
# "é" → "e" + "́" (separate characters)
|
||||
|
||||
# NFC: Canonical composition (default)
|
||||
tokenizer.normalizer = NFC()
|
||||
# "e" + "́" → "é" (composed)
|
||||
|
||||
# NFKD: Compatibility decomposition
|
||||
tokenizer.normalizer = NFKD()
|
||||
# "fi" → "f" + "i"
|
||||
|
||||
# NFKC: Compatibility composition
|
||||
tokenizer.normalizer = NFKC()
|
||||
# Most aggressive normalization
|
||||
```
|
||||
|
||||
**Strip accents**:
|
||||
```python
|
||||
from tokenizers.normalizers import StripAccents
|
||||
|
||||
tokenizer.normalizer = StripAccents()
|
||||
|
||||
# Input: "café"
|
||||
# Output: "cafe"
|
||||
```
|
||||
|
||||
**Whitespace handling**:
|
||||
```python
|
||||
from tokenizers.normalizers import Strip, StripAccents
|
||||
|
||||
# Remove leading/trailing whitespace
|
||||
tokenizer.normalizer = Strip()
|
||||
|
||||
# Input: " hello "
|
||||
# Output: "hello"
|
||||
```
|
||||
|
||||
**Replace patterns**:
|
||||
```python
|
||||
from tokenizers.normalizers import Replace
|
||||
|
||||
# Replace newlines with spaces
|
||||
tokenizer.normalizer = Replace("\\n", " ")
|
||||
|
||||
# Input: "hello\\nworld"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
### Combining normalizers
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
|
||||
|
||||
# BERT-style normalization
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode decomposition
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Café au Lait"
|
||||
# After NFD: "Café au Lait" (e + ́)
|
||||
# After Lowercase: "café au lait"
|
||||
# After StripAccents: "cafe au lait"
|
||||
```
|
||||
|
||||
### Use case examples
|
||||
|
||||
**Case-insensitive model (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
# All-in-one BERT normalization
|
||||
tokenizer.normalizer = BertNormalizer(
|
||||
clean_text=True, # Remove control characters
|
||||
handle_chinese_chars=True, # Add spaces around Chinese
|
||||
strip_accents=True, # Remove accents
|
||||
lowercase=True # Lowercase
|
||||
)
|
||||
```
|
||||
|
||||
**Case-sensitive model (GPT-2)**:
|
||||
```python
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
```
|
||||
|
||||
**Multilingual (mBERT)**:
|
||||
```python
|
||||
# Preserve scripts, normalize form
|
||||
tokenizer.normalizer = NFKC()
|
||||
```
|
||||
|
||||
## Pre-tokenizers
|
||||
|
||||
Split text into word-like units before tokenization.
|
||||
|
||||
### Whitespace splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Input: "Hello world! How are you?"
|
||||
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
|
||||
```
|
||||
|
||||
### Punctuation isolation
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Punctuation
|
||||
|
||||
tokenizer.pre_tokenizer = Punctuation()
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Byte-level (GPT-2)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: Byte-level tokens with Ġ prefix for spaces
|
||||
# [("ĠHello", ...), ("Ġworld", ...)]
|
||||
```
|
||||
|
||||
**Key feature**: Handles ALL Unicode characters (256 byte combinations)
|
||||
|
||||
### Metaspace (SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: [("▁Hello", ...), ("▁world", ...)]
|
||||
```
|
||||
|
||||
**Used by**: T5, ALBERT (via SentencePiece)
|
||||
|
||||
### Digits splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
|
||||
|
||||
# Keep digits together
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=False)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("123", ...)]
|
||||
```
|
||||
|
||||
### BERT pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Splits on whitespace and punctuation, preserves CJK
|
||||
# Input: "Hello, 世界!"
|
||||
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Combining pre-tokenizers
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Sequence, Whitespace, Punctuation
|
||||
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(), # Split on whitespace first
|
||||
Punctuation() # Then isolate punctuation
|
||||
])
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# After Whitespace: [("Hello,", ...), ("world!", ...)]
|
||||
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Pre-tokenizer comparison
|
||||
|
||||
| Pre-tokenizer | Use Case | Example |
|
||||
|-------------------|---------------------------------|--------------------------------------------|
|
||||
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
|
||||
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
|
||||
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
|
||||
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
|
||||
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
|
||||
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
|
||||
|
||||
## Models
|
||||
|
||||
Core tokenization algorithms.
|
||||
|
||||
### BPE Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import BPE
|
||||
|
||||
model = BPE(
|
||||
vocab=None, # Or provide pre-built vocab
|
||||
merges=None, # Or provide merge rules
|
||||
unk_token="[UNK]", # Unknown token
|
||||
continuing_subword_prefix="",
|
||||
end_of_word_suffix="",
|
||||
fuse_unk=False # Keep unknown tokens separate
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `vocab`: Dict of token → id
|
||||
- `merges`: List of merge rules `["a b", "ab c"]`
|
||||
- `unk_token`: Token for unknown words
|
||||
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
|
||||
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
|
||||
|
||||
### WordPiece Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
|
||||
model = WordPiece(
|
||||
vocab=None,
|
||||
unk_token="[UNK]",
|
||||
max_input_chars_per_word=100, # Max word length
|
||||
continuing_subword_prefix="##" # BERT-style prefix
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Key difference**: Uses `##` prefix for continuing subwords.
|
||||
|
||||
### Unigram Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
|
||||
model = Unigram(
|
||||
vocab=None, # List of (token, score) tuples
|
||||
unk_id=0, # ID for unknown token
|
||||
byte_fallback=False # Fall back to bytes if no match
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Probabilistic**: Selects tokenization with highest probability.
|
||||
|
||||
### WordLevel Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordLevel
|
||||
|
||||
# Simple word-to-ID mapping (no subwords)
|
||||
model = WordLevel(
|
||||
vocab=None,
|
||||
unk_token="[UNK]"
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Warning**: Requires huge vocabulary (one token per word).
|
||||
|
||||
## Post-processors
|
||||
|
||||
Add special tokens and format output.
|
||||
|
||||
### Template processing
|
||||
|
||||
**BERT-style** (`[CLS] sentence [SEP]`):
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 101),
|
||||
("[SEP]", 102),
|
||||
],
|
||||
)
|
||||
|
||||
# Single sentence
|
||||
output = tokenizer.encode("Hello world")
|
||||
# [101, ..., 102] ([CLS] hello world [SEP])
|
||||
|
||||
# Sentence pair
|
||||
output = tokenizer.encode("Hello", "world")
|
||||
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
|
||||
```
|
||||
|
||||
**GPT-2 style** (`sentence <|endoftext|>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", 50256),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**RoBERTa style** (`<s> sentence </s>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[
|
||||
("<s>", 0),
|
||||
("</s>", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**T5 style** (no special tokens):
|
||||
```python
|
||||
# T5 doesn't add special tokens via post-processor
|
||||
tokenizer.post_processor = None
|
||||
```
|
||||
|
||||
### RobertaProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import RobertaProcessing
|
||||
|
||||
tokenizer.post_processor = RobertaProcessing(
|
||||
sep=("</s>", 2),
|
||||
cls=("<s>", 0),
|
||||
add_prefix_space=True, # Add space before first token
|
||||
trim_offsets=True # Trim leading space from offsets
|
||||
)
|
||||
```
|
||||
|
||||
### ByteLevelProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import ByteLevel as ByteLevelProcessing
|
||||
|
||||
tokenizer.post_processor = ByteLevelProcessing(
|
||||
trim_offsets=True # Remove Ġ from offsets
|
||||
)
|
||||
```
|
||||
|
||||
## Decoders
|
||||
|
||||
Convert token IDs back to text.
|
||||
|
||||
### ByteLevel decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import ByteLevel
|
||||
|
||||
tokenizer.decoder = ByteLevel()
|
||||
|
||||
# Handles byte-level tokens
|
||||
# ["ĠHello", "Ġworld"] → "Hello world"
|
||||
```
|
||||
|
||||
### WordPiece decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import WordPiece
|
||||
|
||||
tokenizer.decoder = WordPiece(prefix="##")
|
||||
|
||||
# Removes ## prefix and concatenates
|
||||
# ["token", "##ization"] → "tokenization"
|
||||
```
|
||||
|
||||
### Metaspace decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Metaspace
|
||||
|
||||
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Converts ▁ back to spaces
|
||||
# ["▁Hello", "▁world"] → "Hello world"
|
||||
```
|
||||
|
||||
### BPEDecoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import BPEDecoder
|
||||
|
||||
tokenizer.decoder = BPEDecoder(suffix="</w>")
|
||||
|
||||
# Removes suffix and concatenates
|
||||
# ["token", "ization</w>"] → "tokenization"
|
||||
```
|
||||
|
||||
### Sequence decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Sequence, ByteLevel, Strip
|
||||
|
||||
tokenizer.decoder = Sequence([
|
||||
ByteLevel(), # Decode byte-level first
|
||||
Strip(' ', 1, 1) # Strip leading/trailing spaces
|
||||
])
|
||||
```
|
||||
|
||||
## Complete pipeline examples
|
||||
|
||||
### BERT tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
from tokenizers.decoders import WordPiece as WordPieceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
|
||||
# Decoder
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
|
||||
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
```
|
||||
|
||||
### GPT-2 tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.normalizers import NFC
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalization (minimal)
|
||||
tokenizer.normalizer = NFC()
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)],
|
||||
)
|
||||
|
||||
# Byte-level decoder
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
```
|
||||
|
||||
### T5 tokenizer (SentencePiece-style)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
from tokenizers.decoders import Metaspace as MetaspaceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Metaspace pre-tokenization
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# No post-processing (T5 doesn't add CLS/SEP)
|
||||
tokenizer.post_processor = None
|
||||
|
||||
# Metaspace decoder
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text.
|
||||
|
||||
### Basic alignment
|
||||
|
||||
```python
|
||||
text = "Hello, world!"
|
||||
output = tokenizer.encode(text)
|
||||
|
||||
for token, (start, end) in zip(output.tokens, output.offsets):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0): ''
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
# [SEP] → [ 0, 0): ''
|
||||
```
|
||||
|
||||
### Word-level alignment
|
||||
|
||||
```python
|
||||
# Get word_ids (which word each token belongs to)
|
||||
encoding = tokenizer.encode("Hello world")
|
||||
word_ids = encoding.word_ids
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER)
|
||||
```python
|
||||
# Align predictions to words
|
||||
predictions = ["O", "B-PER", "I-PER", "O", "O"]
|
||||
word_predictions = {}
|
||||
|
||||
for token_idx, word_idx in enumerate(encoding.word_ids):
|
||||
if word_idx is not None and word_idx not in word_predictions:
|
||||
word_predictions[word_idx] = predictions[token_idx]
|
||||
|
||||
print(word_predictions)
|
||||
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
|
||||
```
|
||||
|
||||
### Span alignment
|
||||
|
||||
```python
|
||||
# Find token span for character span
|
||||
text = "Machine learning is awesome"
|
||||
char_start, char_end = 8, 16 # "learning"
|
||||
|
||||
encoding = tokenizer.encode(text)
|
||||
|
||||
# Find token span
|
||||
token_start = encoding.char_to_token(char_start)
|
||||
token_end = encoding.char_to_token(char_end - 1) + 1
|
||||
|
||||
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
|
||||
# Tokens 2:3 = ['learning']
|
||||
```
|
||||
|
||||
**Use case**: Question answering (extract answer span)
|
||||
|
||||
## Custom components
|
||||
|
||||
### Custom normalizer
|
||||
|
||||
```python
|
||||
from tokenizers import NormalizedString, Normalizer
|
||||
|
||||
class CustomNormalizer:
|
||||
def normalize(self, normalized: NormalizedString):
|
||||
# Custom normalization logic
|
||||
normalized.lowercase()
|
||||
normalized.replace(" ", " ") # Replace double spaces
|
||||
|
||||
# Use custom normalizer
|
||||
tokenizer.normalizer = CustomNormalizer()
|
||||
```
|
||||
|
||||
### Custom pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import PreTokenizedString
|
||||
|
||||
class CustomPreTokenizer:
|
||||
def pre_tokenize(self, pretok: PreTokenizedString):
|
||||
# Custom pre-tokenization logic
|
||||
pretok.split(lambda i, char: char.isspace())
|
||||
|
||||
tokenizer.pre_tokenizer = CustomPreTokenizer()
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Misaligned offsets
|
||||
|
||||
**Symptom**: Offsets don't match original text
|
||||
```python
|
||||
text = " hello" # Leading spaces
|
||||
offsets = [(0, 5)] # Expects " hel"
|
||||
```
|
||||
|
||||
**Solution**: Check normalization strips spaces
|
||||
```python
|
||||
# Preserve offsets
|
||||
tokenizer.normalizer = Sequence([
|
||||
Strip(), # This changes offsets!
|
||||
])
|
||||
|
||||
# Use trim_offsets in post-processor instead
|
||||
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
|
||||
```
|
||||
|
||||
### Issue: Special tokens not added
|
||||
|
||||
**Symptom**: No [CLS] or [SEP] in output
|
||||
|
||||
**Solution**: Check post-processor is set
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Incorrect decoding
|
||||
|
||||
**Symptom**: Decoded text has ## or ▁
|
||||
|
||||
**Solution**: Set correct decoder
|
||||
```python
|
||||
# For WordPiece
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# For SentencePiece
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match pipeline to model architecture**:
|
||||
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
|
||||
- GPT-2 → NFC + ByteLevel + BPE
|
||||
- T5 → NFKC + Metaspace + Unigram
|
||||
|
||||
2. **Test pipeline on sample inputs**:
|
||||
- Check normalization doesn't over-normalize
|
||||
- Verify pre-tokenization splits correctly
|
||||
- Ensure decoding reconstructs text
|
||||
|
||||
3. **Preserve alignment for downstream tasks**:
|
||||
- Use `trim_offsets` instead of stripping in normalizer
|
||||
- Test `char_to_token()` on sample spans
|
||||
|
||||
4. **Document your pipeline**:
|
||||
- Save complete tokenizer config
|
||||
- Document special tokens
|
||||
- Note any custom components
|
||||
@@ -0,0 +1,565 @@
|
||||
# Training Custom Tokenizers
|
||||
|
||||
Complete guide to training tokenizers from scratch.
|
||||
|
||||
## Training workflow
|
||||
|
||||
### Step 1: Choose tokenization algorithm
|
||||
|
||||
**Decision tree**:
|
||||
- **GPT-style model** → BPE
|
||||
- **BERT-style model** → WordPiece
|
||||
- **Multilingual/No word boundaries** → Unigram
|
||||
|
||||
### Step 2: Prepare training data
|
||||
|
||||
```python
|
||||
# Option 1: From files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
|
||||
# Option 2: From Python list
|
||||
texts = [
|
||||
"This is the first sentence.",
|
||||
"This is the second sentence.",
|
||||
# ... more texts
|
||||
]
|
||||
|
||||
# Option 3: From dataset iterator
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
```
|
||||
|
||||
### Step 3: Initialize tokenizer
|
||||
|
||||
**BPE example**:
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
min_frequency=2,
|
||||
special_tokens=["<|endoftext|>", "<|padding|>"],
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**WordPiece example**:
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**Unigram example**:
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
|
||||
unk_token="<unk>",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Step 4: Train
|
||||
|
||||
```python
|
||||
# From files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
|
||||
# From iterator (recommended for large datasets)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # Optional, for progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Training time** (30k vocab on 16-core CPU):
|
||||
- 10 MB: 15-30 seconds
|
||||
- 100 MB: 1-3 minutes
|
||||
- 1 GB: 15-30 minutes
|
||||
- 10 GB: 2-4 hours
|
||||
|
||||
### Step 5: Add post-processing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", tokenizer.token_to_id("[CLS]")),
|
||||
("[SEP]", tokenizer.token_to_id("[SEP]")),
|
||||
],
|
||||
)
|
||||
|
||||
# GPT-2 style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Step 6: Save
|
||||
|
||||
```python
|
||||
# Save to JSON
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Save to directory (for transformers)
|
||||
tokenizer.save("my-tokenizer-dir/tokenizer.json")
|
||||
|
||||
# Convert to transformers format
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
|
||||
```
|
||||
|
||||
## Trainer configuration
|
||||
|
||||
### BpeTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000, # Target vocabulary size
|
||||
min_frequency=2, # Minimum frequency for merges
|
||||
special_tokens=["[UNK]"], # Special tokens (added first)
|
||||
limit_alphabet=1000, # Limit initial alphabet size
|
||||
initial_alphabet=[], # Pre-defined initial characters
|
||||
show_progress=True, # Show progress bar
|
||||
continuing_subword_prefix="", # Prefix for continuing subwords
|
||||
end_of_word_suffix="" # Suffix for end of words
|
||||
)
|
||||
```
|
||||
|
||||
**Parameter tuning**:
|
||||
- **vocab_size**: Start with 30k for English, 50k for multilingual
|
||||
- **min_frequency**: 2-5 for large corpora, 1 for small
|
||||
- **limit_alphabet**: Reduce for non-English (CJK languages)
|
||||
|
||||
### WordPieceTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT uses 30,522
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
limit_alphabet=1000,
|
||||
continuing_subword_prefix="##", # BERT-style prefix
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### UnigramTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000, # Typically smaller than BPE/WordPiece
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Maximum token length
|
||||
n_sub_iterations=2, # EM algorithm iterations
|
||||
shrinking_factor=0.75, # Vocabulary reduction rate
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
## Training from large datasets
|
||||
|
||||
### Memory-efficient training
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
|
||||
|
||||
# Create iterator (yields batches)
|
||||
def batch_iterator(batch_size=1000):
|
||||
batch = []
|
||||
for sample in dataset:
|
||||
batch.append(sample["text"])
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
|
||||
|
||||
# Train (memory efficient - streams data)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer
|
||||
)
|
||||
```
|
||||
|
||||
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
|
||||
|
||||
### Multi-file training
|
||||
|
||||
```python
|
||||
import glob
|
||||
|
||||
# Find all training files
|
||||
files = glob.glob("data/train/*.txt")
|
||||
print(f"Training on {len(files)} files")
|
||||
|
||||
# Train on all files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
```
|
||||
|
||||
### Parallel training (multi-processing)
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import os
|
||||
|
||||
def train_shard(shard_files):
|
||||
"""Train tokenizer on a shard of files."""
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000)
|
||||
tokenizer.train(files=shard_files, trainer=trainer)
|
||||
return tokenizer.get_vocab()
|
||||
|
||||
# Split files into shards
|
||||
num_shards = cpu_count()
|
||||
file_shards = [files[i::num_shards] for i in range(num_shards)]
|
||||
|
||||
# Train shards in parallel
|
||||
with Pool(num_shards) as pool:
|
||||
vocab_shards = pool.map(train_shard, file_shards)
|
||||
|
||||
# Merge vocabularies (custom logic needed)
|
||||
# This is a simplified example - real implementation would merge intelligently
|
||||
final_vocab = {}
|
||||
for vocab in vocab_shards:
|
||||
final_vocab.update(vocab)
|
||||
```
|
||||
|
||||
## Domain-specific tokenizers
|
||||
|
||||
### Code tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.normalizers import Sequence, NFC
|
||||
|
||||
# Code-optimized configuration
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization (preserve case, whitespace)
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
|
||||
# Byte-level pre-tokenization (handles all characters)
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Train on code corpus
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["<|endoftext|>", "<|pad|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Medical/scientific tokenizer
|
||||
|
||||
```python
|
||||
# Preserve case and special characters
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Preserve medical terms
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation(behavior="isolated") # Keep punctuation separate
|
||||
])
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
|
||||
min_frequency=3 # Higher threshold for rare medical terms
|
||||
)
|
||||
|
||||
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Multilingual tokenizer
|
||||
|
||||
```python
|
||||
# Handle multiple scripts
|
||||
from tokenizers.normalizers import NFKC, Lowercase, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalize but don't lowercase (preserves script differences)
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Byte-level handles all Unicode
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=100000, # Larger vocab for multiple languages
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
limit_alphabet=None # No limit (handles all scripts)
|
||||
)
|
||||
|
||||
# Train on multilingual corpus
|
||||
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
## Vocabulary size selection
|
||||
|
||||
### Guidelines by task
|
||||
|
||||
| Task | Recommended Vocab Size | Rationale |
|
||||
|-----------------------|------------------------|-----------|
|
||||
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
|
||||
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
|
||||
| Code | 30,000 - 50,000 | Similar to English |
|
||||
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
|
||||
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
|
||||
|
||||
### Vocabulary size impact
|
||||
|
||||
**Small vocab (10k)**:
|
||||
- Pros: Faster training, smaller model, less memory
|
||||
- Cons: More tokens per sentence, worse OOV handling
|
||||
|
||||
**Medium vocab (30k-50k)**:
|
||||
- Pros: Good balance, standard choice
|
||||
- Cons: None (recommended default)
|
||||
|
||||
**Large vocab (100k+)**:
|
||||
- Pros: Fewer tokens per sentence, better OOV
|
||||
- Cons: Slower training, larger embedding table
|
||||
|
||||
### Empirical testing
|
||||
|
||||
```python
|
||||
# Train multiple tokenizers with different vocab sizes
|
||||
vocab_sizes = [10000, 30000, 50000, 100000]
|
||||
|
||||
for vocab_size in vocab_sizes:
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=vocab_size)
|
||||
tokenizer.train(files=["sample.txt"], trainer=trainer)
|
||||
|
||||
# Evaluate on test set
|
||||
test_text = "Test sentence for evaluation..."
|
||||
tokens = tokenizer.encode(test_text).ids
|
||||
|
||||
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
|
||||
|
||||
# Example output:
|
||||
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
|
||||
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
|
||||
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
|
||||
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
|
||||
```
|
||||
|
||||
## Testing tokenizer quality
|
||||
|
||||
### Coverage test
|
||||
|
||||
```python
|
||||
# Test on held-out data
|
||||
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
|
||||
|
||||
total_tokens = 0
|
||||
unk_tokens = 0
|
||||
unk_id = tokenizer.token_to_id("[UNK]")
|
||||
|
||||
for text in test_corpus["text"]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
total_tokens += len(encoding.ids)
|
||||
unk_tokens += encoding.ids.count(unk_id)
|
||||
|
||||
unk_rate = unk_tokens / total_tokens
|
||||
print(f"Unknown token rate: {unk_rate:.2%}")
|
||||
|
||||
# Good quality: <1% unknown tokens
|
||||
# Acceptable: 1-5%
|
||||
# Poor: >5%
|
||||
```
|
||||
|
||||
### Compression test
|
||||
|
||||
```python
|
||||
# Measure tokenization efficiency
|
||||
import numpy as np
|
||||
|
||||
token_lengths = []
|
||||
|
||||
for text in test_corpus["text"][:1000]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
chars_per_token = len(text) / len(encoding.ids)
|
||||
token_lengths.append(chars_per_token)
|
||||
|
||||
avg_chars_per_token = np.mean(token_lengths)
|
||||
print(f"Average characters per token: {avg_chars_per_token:.2f}")
|
||||
|
||||
# Good: 4-6 chars/token (English)
|
||||
# Acceptable: 3-4 chars/token
|
||||
# Poor: <3 chars/token (under-compression)
|
||||
```
|
||||
|
||||
### Semantic test
|
||||
|
||||
```python
|
||||
# Manually inspect tokenization of common words/phrases
|
||||
test_phrases = [
|
||||
"tokenization",
|
||||
"machine learning",
|
||||
"artificial intelligence",
|
||||
"preprocessing",
|
||||
"hello world"
|
||||
]
|
||||
|
||||
for phrase in test_phrases:
|
||||
tokens = tokenizer.encode(phrase).tokens
|
||||
print(f"{phrase:25s} → {tokens}")
|
||||
|
||||
# Good tokenization:
|
||||
# tokenization → ['token', 'ization']
|
||||
# machine learning → ['machine', 'learning']
|
||||
# artificial intelligence → ['artificial', 'intelligence']
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Training too slow
|
||||
|
||||
**Solutions**:
|
||||
1. Reduce vocabulary size
|
||||
2. Increase `min_frequency`
|
||||
3. Use `limit_alphabet` to reduce initial alphabet
|
||||
4. Train on subset first
|
||||
|
||||
```python
|
||||
# Fast training configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=20000, # Smaller vocab
|
||||
min_frequency=5, # Higher threshold
|
||||
limit_alphabet=500, # Limit alphabet
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: High unknown token rate
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Decrease `min_frequency`
|
||||
3. Check normalization (might be too aggressive)
|
||||
|
||||
```python
|
||||
# Better coverage configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000, # Larger vocab
|
||||
min_frequency=1, # Lower threshold
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Poor quality tokenization
|
||||
|
||||
**Solutions**:
|
||||
1. Verify normalization matches your use case
|
||||
2. Check pre-tokenization splits correctly
|
||||
3. Ensure training data is representative
|
||||
4. Try different algorithm (BPE vs WordPiece vs Unigram)
|
||||
|
||||
```python
|
||||
# Debug tokenization pipeline
|
||||
text = "Sample text to debug"
|
||||
|
||||
# Check normalization
|
||||
normalized = tokenizer.normalizer.normalize_str(text)
|
||||
print(f"Normalized: {normalized}")
|
||||
|
||||
# Check pre-tokenization
|
||||
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
|
||||
print(f"Pre-tokens: {pre_tokens}")
|
||||
|
||||
# Check final tokenization
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
print(f"Tokens: {tokens}")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use representative training data** - Match your target domain
|
||||
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
|
||||
3. **Test on held-out data** - Measure unknown token rate
|
||||
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
|
||||
5. **Save tokenizer with model** - Ensure reproducibility
|
||||
6. **Version your tokenizers** - Track changes for reproducibility
|
||||
7. **Document special tokens** - Critical for model training
|
||||
@@ -0,0 +1,235 @@
|
||||
---
|
||||
name: sentencepiece
|
||||
description: Language-independent tokenizer treating text as raw Unicode. Supports BPE and Unigram algorithms. Fast (50k sentences/sec), lightweight (6MB memory), deterministic vocabulary. Used by T5, ALBERT, XLNet, mBART. Train on raw text without pre-tokenization. Use when you need multilingual support, CJK languages, or reproducible tokenization.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Tokenization, SentencePiece, Language-Independent, BPE, Unigram, Multilingual, CJK Languages, Unicode, Deterministic, Google]
|
||||
dependencies: [sentencepiece, transformers]
|
||||
---
|
||||
|
||||
# SentencePiece - Language-Independent Tokenization
|
||||
|
||||
Unsupervised tokenizer that works on raw text without language-specific preprocessing.
|
||||
|
||||
## When to use SentencePiece
|
||||
|
||||
**Use SentencePiece when:**
|
||||
- Building multilingual models (no language-specific rules)
|
||||
- Working with CJK languages (Chinese, Japanese, Korean)
|
||||
- Need reproducible tokenization (deterministic vocabulary)
|
||||
- Want to train on raw text (no pre-tokenization needed)
|
||||
- Require lightweight deployment (6MB memory, 50k sentences/sec)
|
||||
|
||||
**Performance**:
|
||||
- **Speed**: 50,000 sentences/sec
|
||||
- **Memory**: ~6MB for loaded model
|
||||
- **Languages**: All (language-independent)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **HuggingFace Tokenizers**: Faster training, more flexibility
|
||||
- **tiktoken**: OpenAI models (GPT-3.5/4)
|
||||
- **BERT WordPiece**: English-centric tasks
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install sentencepiece
|
||||
|
||||
# C++ (requires CMake)
|
||||
git clone https://github.com/google/sentencepiece.git
|
||||
cd sentencepiece
|
||||
mkdir build && cd build
|
||||
cmake .. && make -j $(nproc)
|
||||
sudo make install
|
||||
```
|
||||
|
||||
### Train model
|
||||
|
||||
```bash
|
||||
# Command-line (BPE with 8000 vocab)
|
||||
spm_train --input=data.txt --model_prefix=m --vocab_size=8000 --model_type=bpe
|
||||
|
||||
# Python API
|
||||
import sentencepiece as spm
|
||||
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='data.txt',
|
||||
model_prefix='m',
|
||||
vocab_size=8000,
|
||||
model_type='bpe'
|
||||
)
|
||||
```
|
||||
|
||||
**Training time**: ~1-2 minutes for 100MB corpus
|
||||
|
||||
### Encode and decode
|
||||
|
||||
```python
|
||||
import sentencepiece as spm
|
||||
|
||||
# Load model
|
||||
sp = spm.SentencePieceProcessor(model_file='m.model')
|
||||
|
||||
# Encode to pieces
|
||||
pieces = sp.encode('This is a test', out_type=str)
|
||||
print(pieces) # ['▁This', '▁is', '▁a', '▁test']
|
||||
|
||||
# Encode to IDs
|
||||
ids = sp.encode('This is a test', out_type=int)
|
||||
print(ids) # [284, 47, 11, 1243]
|
||||
|
||||
# Decode
|
||||
text = sp.decode(ids)
|
||||
print(text) # "This is a test"
|
||||
```
|
||||
|
||||
## Language-independent design
|
||||
|
||||
### Whitespace as symbol (▁)
|
||||
|
||||
```python
|
||||
text = "Hello world"
|
||||
pieces = sp.encode(text, out_type=str)
|
||||
print(pieces) # ['▁Hello', '▁world']
|
||||
|
||||
# Decode preserves spaces
|
||||
decoded = sp.decode_pieces(pieces)
|
||||
print(decoded) # "Hello world"
|
||||
```
|
||||
|
||||
**Key principle**: Treat text as raw Unicode, whitespace = ▁ (meta symbol)
|
||||
|
||||
## Tokenization algorithms
|
||||
|
||||
### BPE (Byte-Pair Encoding)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='data.txt',
|
||||
model_prefix='bpe_model',
|
||||
vocab_size=16000,
|
||||
model_type='bpe'
|
||||
)
|
||||
```
|
||||
|
||||
**Used by**: mBART
|
||||
|
||||
### Unigram (default)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='data.txt',
|
||||
model_prefix='unigram_model',
|
||||
vocab_size=8000,
|
||||
model_type='unigram'
|
||||
)
|
||||
```
|
||||
|
||||
**Used by**: T5, ALBERT, XLNet
|
||||
|
||||
## Training configuration
|
||||
|
||||
### Essential parameters
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_prefix='m',
|
||||
vocab_size=32000,
|
||||
model_type='unigram',
|
||||
character_coverage=0.9995, # 1.0 for CJK
|
||||
user_defined_symbols=['[SEP]', '[CLS]'],
|
||||
unk_piece='<unk>',
|
||||
num_threads=16
|
||||
)
|
||||
```
|
||||
|
||||
### Character coverage
|
||||
|
||||
| Language Type | Coverage | Rationale |
|
||||
|---------------|----------|-----------|
|
||||
| English | 0.9995 | Most common chars |
|
||||
| CJK (Chinese) | 1.0 | All characters needed |
|
||||
| Multilingual | 0.9995 | Balance |
|
||||
|
||||
## Encoding options
|
||||
|
||||
### Subword regularization
|
||||
|
||||
```python
|
||||
# Sample different tokenizations
|
||||
for _ in range(3):
|
||||
pieces = sp.encode('tokenization', out_type=str, enable_sampling=True, alpha=0.1)
|
||||
print(pieces)
|
||||
|
||||
# Output (different each time):
|
||||
# ['▁token', 'ization']
|
||||
# ['▁tok', 'en', 'ization']
|
||||
```
|
||||
|
||||
**Use case**: Data augmentation for robustness.
|
||||
|
||||
## Common patterns
|
||||
|
||||
### T5-style training
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='c4_corpus.txt',
|
||||
model_prefix='t5',
|
||||
vocab_size=32000,
|
||||
model_type='unigram',
|
||||
user_defined_symbols=[f'<extra_id_{i}>' for i in range(100)],
|
||||
unk_id=2,
|
||||
eos_id=1,
|
||||
pad_id=0
|
||||
)
|
||||
```
|
||||
|
||||
### Integration with transformers
|
||||
|
||||
```python
|
||||
from transformers import T5Tokenizer
|
||||
|
||||
# T5 uses SentencePiece internally
|
||||
tokenizer = T5Tokenizer.from_pretrained('t5-base')
|
||||
inputs = tokenizer('translate English to French: Hello', return_tensors='pt')
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Training speed
|
||||
|
||||
| Corpus | BPE (16k) | Unigram (8k) |
|
||||
|--------|-----------|--------------|
|
||||
| 100 MB | 1-2 min | 3-4 min |
|
||||
| 1 GB | 10-15 min | 30-40 min |
|
||||
|
||||
### Tokenization speed
|
||||
|
||||
- **SentencePiece**: 50,000 sentences/sec
|
||||
- **HF Tokenizers**: 200,000 sentences/sec (4× faster)
|
||||
|
||||
## Supported models
|
||||
|
||||
**T5 family**: `t5-base`, `t5-large` (32k vocab, Unigram)
|
||||
**ALBERT**: `albert-base-v2` (30k vocab, Unigram)
|
||||
**XLNet**: `xlnet-base-cased` (32k vocab, Unigram)
|
||||
**mBART**: `facebook/mbart-large-50` (250k vocab, BPE)
|
||||
|
||||
## References
|
||||
|
||||
- **[Training Guide](references/training.md)** - Detailed options, corpus preparation
|
||||
- **[Algorithms](references/algorithms.md)** - BPE vs Unigram, subword regularization
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/google/sentencepiece ⭐ 10,000+
|
||||
- **Paper**: https://arxiv.org/abs/1808.06226 (EMNLP 2018)
|
||||
- **Version**: 0.2.0+
|
||||
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
# Tokenization Algorithms
|
||||
|
||||
BPE vs Unigram comparison and subword regularization.
|
||||
|
||||
## BPE (Byte-Pair Encoding)
|
||||
|
||||
### Algorithm
|
||||
|
||||
1. Initialize vocabulary with characters
|
||||
2. Count frequency of adjacent token pairs
|
||||
3. Merge most frequent pair
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
### Example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
- Most frequent pair: 'e' + 's' (9 times)
|
||||
- Merge → 'es'
|
||||
- Vocabulary: [chars] + ['es']
|
||||
|
||||
**Iteration 2**:
|
||||
- Most frequent: 'es' + 't' (9 times)
|
||||
- Merge → 'est'
|
||||
- Vocabulary: [chars] + ['es', 'est']
|
||||
|
||||
**Result**: `newest` → `new|est`, `widest` → `wid|est`
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
import sentencepiece as spm
|
||||
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_type='bpe',
|
||||
vocab_size=16000
|
||||
)
|
||||
```
|
||||
|
||||
### Advantages
|
||||
|
||||
- Simple algorithm
|
||||
- Fast training
|
||||
- Good compression ratio
|
||||
|
||||
### Disadvantages
|
||||
|
||||
- Deterministic (no sampling)
|
||||
- May split common words unexpectedly
|
||||
|
||||
## Unigram
|
||||
|
||||
### Algorithm
|
||||
|
||||
1. Start with large vocabulary (all substrings)
|
||||
2. Compute probability of each token
|
||||
3. Remove tokens with minimal loss impact
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
### Probabilistic tokenization
|
||||
|
||||
Given vocabulary with probabilities:
|
||||
```
|
||||
P('low') = 0.02
|
||||
P('est') = 0.03
|
||||
P('l') = 0.01
|
||||
P('o') = 0.015
|
||||
...
|
||||
```
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Option 1: ['low', 'est']
|
||||
P = 0.02 × 0.03 = 0.0006 ← highest
|
||||
|
||||
Option 2: ['l', 'o', 'w', 'est']
|
||||
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
|
||||
|
||||
Choose option 1 (highest probability)
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_type='unigram',
|
||||
vocab_size=8000
|
||||
)
|
||||
```
|
||||
|
||||
### Advantages
|
||||
|
||||
- Probabilistic (can sample)
|
||||
- Better for morphologically rich languages
|
||||
- Supports subword regularization
|
||||
|
||||
### Disadvantages
|
||||
|
||||
- Slower training
|
||||
- More complex algorithm
|
||||
|
||||
## Comparison
|
||||
|
||||
| Feature | BPE | Unigram |
|
||||
|---------|-----|---------|
|
||||
| Training speed | Fast | Slow |
|
||||
| Tokenization | Deterministic | Probabilistic |
|
||||
| Sampling | No | Yes |
|
||||
| Typical vocab size | 16k-32k | 8k-32k |
|
||||
| Used by | mBART | T5, ALBERT, XLNet |
|
||||
|
||||
## Subword regularization
|
||||
|
||||
Sample different tokenizations during training for robustness.
|
||||
|
||||
### Enable sampling
|
||||
|
||||
```python
|
||||
sp = spm.SentencePieceProcessor(model_file='m.model')
|
||||
|
||||
# Sample different tokenizations
|
||||
for _ in range(5):
|
||||
pieces = sp.encode('tokenization', out_type=str, enable_sampling=True, alpha=0.1)
|
||||
print(pieces)
|
||||
|
||||
# Output (different each time):
|
||||
# ['▁token', 'ization']
|
||||
# ['▁tok', 'en', 'ization']
|
||||
# ['▁token', 'iz', 'ation']
|
||||
# ['▁to', 'ken', 'ization']
|
||||
# ['▁token', 'ization']
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
- `alpha`: Regularization strength
|
||||
- 0.0 = deterministic (no sampling)
|
||||
- 0.1 = slight variation
|
||||
- 0.5 = high variation
|
||||
- 1.0 = maximum variation
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Robustness**: Model learns multiple tokenizations
|
||||
2. **Data augmentation**: More diverse training data
|
||||
3. **Better generalization**: Less overfitting to specific tokenization
|
||||
|
||||
### Use case
|
||||
|
||||
```python
|
||||
# Training loop with regularization
|
||||
for batch in dataloader:
|
||||
# Sample different tokenizations each epoch
|
||||
tokens = sp.encode(batch['text'], enable_sampling=True, alpha=0.1)
|
||||
# Train model...
|
||||
```
|
||||
|
||||
**Used by**: mT5, XLM-RoBERTa
|
||||
|
||||
## NBest encoding
|
||||
|
||||
Get multiple tokenization candidates with scores.
|
||||
|
||||
```python
|
||||
sp = spm.SentencePieceProcessor(model_file='m.model')
|
||||
|
||||
# Get top-5 tokenizations
|
||||
nbest = sp.nbest_encode('tokenization', nbest_size=5, out_type=str)
|
||||
|
||||
for pieces, score in nbest:
|
||||
print(f"{pieces} (log prob: {score:.4f})")
|
||||
|
||||
# Output:
|
||||
# ['▁token', 'ization'] (log prob: -2.34)
|
||||
# ['▁tok', 'en', 'ization'] (log prob: -2.41)
|
||||
# ['▁token', 'iz', 'ation'] (log prob: -2.57)
|
||||
```
|
||||
|
||||
### Use cases
|
||||
|
||||
1. **Ensemble tokenization**: Average over multiple tokenizations
|
||||
2. **Uncertainty estimation**: Check variance in scores
|
||||
3. **Debugging**: Understand tokenizer behavior
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use Unigram for multilingual** - Better for diverse languages
|
||||
2. **Use BPE for speed** - Faster training and inference
|
||||
3. **Enable subword regularization** - Improves model robustness
|
||||
4. **Set alpha=0.1 for slight variation** - Good balance
|
||||
5. **Use deterministic mode for inference** - Consistent results
|
||||
@@ -0,0 +1,304 @@
|
||||
# SentencePiece Training Guide
|
||||
|
||||
Complete guide to training SentencePiece models.
|
||||
|
||||
## Training workflow
|
||||
|
||||
### Step 1: Prepare corpus
|
||||
|
||||
```bash
|
||||
# Plain text file, one sentence per line (recommended)
|
||||
cat corpus.txt
|
||||
# Hello world
|
||||
# This is a test
|
||||
# SentencePiece is language-independent
|
||||
|
||||
# Or use raw text (SentencePiece handles sentence splitting)
|
||||
```
|
||||
|
||||
### Step 2: Train model
|
||||
|
||||
**Command-line**:
|
||||
```bash
|
||||
spm_train \
|
||||
--input=corpus.txt \
|
||||
--model_prefix=m \
|
||||
--vocab_size=8000 \
|
||||
--model_type=unigram \
|
||||
--character_coverage=0.9995
|
||||
```
|
||||
|
||||
**Python API**:
|
||||
```python
|
||||
import sentencepiece as spm
|
||||
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_prefix='m',
|
||||
vocab_size=8000,
|
||||
model_type='unigram'
|
||||
)
|
||||
```
|
||||
|
||||
**Output**: `m.model` (binary), `m.vocab` (text vocabulary)
|
||||
|
||||
### Step 3: Load and use
|
||||
|
||||
```python
|
||||
sp = spm.SentencePieceProcessor(model_file='m.model')
|
||||
pieces = sp.encode('Test sentence', out_type=str)
|
||||
```
|
||||
|
||||
## Training parameters
|
||||
|
||||
### Core parameters
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
# Required
|
||||
input='corpus.txt', # Input corpus
|
||||
model_prefix='output', # Output prefix
|
||||
vocab_size=8000, # Target vocabulary size
|
||||
|
||||
# Algorithm
|
||||
model_type='unigram', # 'unigram', 'bpe', 'char', 'word'
|
||||
|
||||
# Coverage
|
||||
character_coverage=0.9995, # 0.9995 for most, 1.0 for CJK
|
||||
|
||||
# Normalization
|
||||
normalization_rule_name='nmt_nfkc', # 'nmt_nfkc', 'nfkc', 'identity'
|
||||
|
||||
# Performance
|
||||
num_threads=16, # Training threads
|
||||
input_sentence_size=10000000 # Max sentences to load
|
||||
)
|
||||
```
|
||||
|
||||
### Special tokens
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_prefix='m',
|
||||
vocab_size=32000,
|
||||
|
||||
# Control symbols (special tokens for model control)
|
||||
control_symbols=['<s>', '</s>', '<pad>'],
|
||||
|
||||
# User-defined symbols (never split)
|
||||
user_defined_symbols=['[MASK]', '[SEP]', '[CLS]'],
|
||||
|
||||
# Special token pieces
|
||||
unk_piece='<unk>',
|
||||
bos_piece='<s>',
|
||||
eos_piece='</s>',
|
||||
pad_piece='<pad>',
|
||||
|
||||
# Special token IDs
|
||||
unk_id=0,
|
||||
bos_id=1,
|
||||
eos_id=2,
|
||||
pad_id=3
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced options
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_prefix='m',
|
||||
vocab_size=32000,
|
||||
|
||||
# Byte fallback (handle unknown chars)
|
||||
byte_fallback=True,
|
||||
|
||||
# Digit handling
|
||||
split_digits=True, # Split digits individually
|
||||
|
||||
# Script splitting
|
||||
split_by_unicode_script=True, # Split by Unicode script
|
||||
split_by_whitespace=True, # Split by whitespace
|
||||
|
||||
# Length constraints
|
||||
max_sentencepiece_length=16, # Max token length
|
||||
|
||||
# Rare word handling
|
||||
min_frequency=2, # Min frequency for token
|
||||
|
||||
# Training size
|
||||
input_sentence_size=10000000, # Max sentences
|
||||
shuffle_input_sentence=True, # Shuffle training data
|
||||
|
||||
# Seed
|
||||
seed_sentencepiece_size=1000000 # Seed vocab size
|
||||
)
|
||||
```
|
||||
|
||||
## Training from Python iterator
|
||||
|
||||
```python
|
||||
import sentencepiece as spm
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
|
||||
|
||||
# Create iterator
|
||||
def corpus_iterator():
|
||||
for example in dataset:
|
||||
if example['text'].strip():
|
||||
yield example['text']
|
||||
|
||||
# Train from iterator
|
||||
spm.SentencePieceTrainer.train(
|
||||
sentence_iterator=corpus_iterator(),
|
||||
model_prefix='wiki',
|
||||
vocab_size=32000,
|
||||
model_type='unigram'
|
||||
)
|
||||
```
|
||||
|
||||
## Model types
|
||||
|
||||
### BPE
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_type='bpe',
|
||||
vocab_size=16000
|
||||
)
|
||||
```
|
||||
|
||||
**Training time**: ~10-15 min for 1GB corpus
|
||||
|
||||
### Unigram (recommended)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='corpus.txt',
|
||||
model_type='unigram',
|
||||
vocab_size=8000
|
||||
)
|
||||
```
|
||||
|
||||
**Training time**: ~30-40 min for 1GB corpus
|
||||
|
||||
## Character coverage
|
||||
|
||||
### English/European (0.9995)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='en_corpus.txt',
|
||||
character_coverage=0.9995 # Cover 99.95% of chars
|
||||
)
|
||||
```
|
||||
|
||||
Covers: a-z, A-Z, punctuation, common accents
|
||||
|
||||
### CJK (1.0)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='zh_corpus.txt',
|
||||
character_coverage=1.0 # Cover ALL characters
|
||||
)
|
||||
```
|
||||
|
||||
Required for: Chinese, Japanese, Korean
|
||||
|
||||
### Multilingual (0.9995-1.0)
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='multilingual_corpus.txt',
|
||||
character_coverage=0.9995 # Balance coverage/size
|
||||
)
|
||||
```
|
||||
|
||||
## Vocabulary size selection
|
||||
|
||||
| Task | Vocab Size | Rationale |
|
||||
|------|------------|-----------|
|
||||
| English monolingual | 16k-32k | Standard |
|
||||
| Multilingual | 32k-250k | More languages |
|
||||
| CJK | 32k-100k | More characters |
|
||||
| Code | 16k-32k | Similar to English |
|
||||
|
||||
## Normalization rules
|
||||
|
||||
### nmt_nfkc (recommended)
|
||||
|
||||
```python
|
||||
normalization_rule_name='nmt_nfkc'
|
||||
```
|
||||
|
||||
- NFKC Unicode normalization
|
||||
- Whitespace handling
|
||||
- **Recommended for most tasks**
|
||||
|
||||
### identity (no normalization)
|
||||
|
||||
```python
|
||||
normalization_rule_name='identity'
|
||||
```
|
||||
|
||||
- Preserves input exactly
|
||||
- Use for code, case-sensitive tasks
|
||||
|
||||
### nfkc (standard Unicode)
|
||||
|
||||
```python
|
||||
normalization_rule_name='nfkc'
|
||||
```
|
||||
|
||||
- Standard Unicode normalization
|
||||
- Less aggressive than nmt_nfkc
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Multi-threading
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='large_corpus.txt',
|
||||
num_threads=32 # Use all cores
|
||||
)
|
||||
```
|
||||
|
||||
**Speedup**: ~4-8× with 16+ cores
|
||||
|
||||
### Sampling input
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='huge_corpus.txt',
|
||||
input_sentence_size=10000000, # Sample 10M sentences
|
||||
shuffle_input_sentence=True
|
||||
)
|
||||
```
|
||||
|
||||
**For very large corpora** (>10GB)
|
||||
|
||||
### Extremely large corpus
|
||||
|
||||
```python
|
||||
spm.SentencePieceTrainer.train(
|
||||
input='massive_corpus.txt',
|
||||
train_extremely_large_corpus=True, # Enable for >10GB
|
||||
input_sentence_size=100000000
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use Unigram for most tasks** - Better for multilingual
|
||||
2. **Set character_coverage=1.0 for CJK** - Required for full coverage
|
||||
3. **Use nmt_nfkc normalization** - Works well for most cases
|
||||
4. **Add user_defined_symbols for special tokens** - BERT-style tokens
|
||||
5. **Enable byte_fallback for robustness** - Handles emojis/rare chars
|
||||
6. **Start with vocab_size=32000** - Good default for most tasks
|
||||
7. **Use multi-threading** - Speeds up training significantly
|
||||
@@ -0,0 +1,158 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
|
||||
|
||||
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,15 @@
|
||||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,80 @@
|
||||
---
|
||||
name: llama-factory
|
||||
description: Expert guidance for fine-tuning LLMs with LLaMA-Factory - WebUI no-code, 100+ models, 2/3/4/5/6/8-bit QLoRA, multimodal support
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, LLaMA Factory, LLM, WebUI, No-Code, QLoRA, LoRA, Multimodal, HuggingFace, Llama, Qwen, Gemma]
|
||||
dependencies: [llmtuner, torch, transformers, datasets, peft, accelerate, gradio]
|
||||
---
|
||||
|
||||
# Llama-Factory Skill
|
||||
|
||||
Comprehensive assistance with llama-factory development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with llama-factory
|
||||
- Asking about llama-factory features or APIs
|
||||
- Implementing llama-factory solutions
|
||||
- Debugging llama-factory code
|
||||
- Learning llama-factory best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
*Quick reference patterns will be added as you use the skill.*
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **_images.md** - Images documentation
|
||||
- **advanced.md** - Advanced documentation
|
||||
- **getting_started.md** - Getting Started documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Llama-Factory - Images
|
||||
|
||||
**Pages:** 3
|
||||
|
||||
---
|
||||
|
||||
##
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/logo.png
|
||||
|
||||
---
|
||||
|
||||
##
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/quantization_0.png
|
||||
|
||||
---
|
||||
|
||||
##
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/_images/webui_0.png
|
||||
|
||||
---
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,349 @@
|
||||
# Llama-Factory - Getting Started
|
||||
|
||||
**Pages:** 7
|
||||
|
||||
---
|
||||
|
||||
## Installation¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/installation.html
|
||||
|
||||
**Contents:**
|
||||
- Installation¶
|
||||
- Linux¶
|
||||
- CUDA 安装¶
|
||||
- Windows¶
|
||||
- CUDA 安装¶
|
||||
- LLaMA-Factory 安装¶
|
||||
- LLaMA-Factory 校验¶
|
||||
- LLaMA-Factory 高级选项¶
|
||||
- Windows¶
|
||||
- QLoRA¶
|
||||
|
||||
CUDA 是由 NVIDIA 创建的一个并行计算平台和编程模型,它让开发者可以使用 NVIDIA 的 GPU 进行高性能的并行计算。
|
||||
|
||||
首先,在 https://developer.nvidia.com/cuda-gpus 查看您的 GPU 是否支持CUDA
|
||||
|
||||
保证当前 Linux 版本支持CUDA. 在命令行中输入 uname -m && cat /etc/*release,应当看到类似的输出
|
||||
|
||||
检查是否安装了 gcc . 在命令行中输入 gcc --version ,应当看到类似的输出
|
||||
|
||||
在以下网址下载所需的 CUDA,这里推荐12.2版本。 https://developer.nvidia.com/cuda-gpus 注意需要根据上述输出选择正确版本
|
||||
|
||||
如果您之前安装过 CUDA(例如为12.1版本),需要先使用 sudo /usr/local/cuda-12.1/bin/cuda-uninstaller 卸载。如果该命令无法运行,可以直接:
|
||||
|
||||
卸载完成后运行以下命令并根据提示继续安装:
|
||||
|
||||
注意:在确定 CUDA 自带驱动版本与 GPU 是否兼容之前,建议取消 Driver 的安装。
|
||||
|
||||
完成后输入 nvcc -V 检查是否出现对应的版本号,若出现则安装完成。
|
||||
|
||||
打开 设置 ,在 关于 中找到 Windows 规格 保证系统版本在以下列表中:
|
||||
|
||||
Microsoft Windows 11 21H2
|
||||
|
||||
Microsoft Windows 11 22H2-SV2
|
||||
|
||||
Microsoft Windows 11 23H2
|
||||
|
||||
Microsoft Windows 10 21H2
|
||||
|
||||
Microsoft Windows 10 22H2
|
||||
|
||||
Microsoft Windows Server 2022
|
||||
|
||||
打开 cmd 输入 nvcc -V ,若出现类似内容则安装成功。
|
||||
|
||||
否则,检查系统环境变量,保证 CUDA 被正确导入。
|
||||
|
||||
在安装 LLaMA-Factory 之前,请确保您安装了下列依赖:
|
||||
|
||||
运行以下指令以安装 LLaMA-Factory 及其依赖:
|
||||
|
||||
如果出现环境冲突,请尝试使用 pip install --no-deps -e . 解决
|
||||
|
||||
完成安装后,可以通过使用 llamafactory-cli version 来快速校验安装是否成功
|
||||
|
||||
如果您能成功看到类似下面的界面,就说明安装成功了。
|
||||
|
||||
如果您想在 Windows 上启用量化 LoRA(QLoRA),请根据您的 CUDA 版本选择适当的 bitsandbytes 发行版本。
|
||||
|
||||
如果您要在 Windows 平台上启用 FlashAttention-2,请根据您的 CUDA 版本选择适当的 flash-attention 发行版本。
|
||||
|
||||
开源深度学习框架 PyTorch,广泛用于机器学习和人工智能研究中。
|
||||
|
||||
提供了加载 Qwen v1 模型所需的包。
|
||||
|
||||
魔搭社区,提供了预训练模型和数据集的下载途径。
|
||||
|
||||
开源训练跟踪工具 SwanLab,用于记录与可视化训练过程
|
||||
|
||||
用于 LLaMA Factory 开发维护。
|
||||
|
||||
---
|
||||
|
||||
## WebUI¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/webui.html
|
||||
|
||||
**Contents:**
|
||||
- WebUI¶
|
||||
- 训练¶
|
||||
- 评估预测与对话¶
|
||||
- 导出¶
|
||||
|
||||
LLaMA-Factory 支持通过 WebUI 零代码微调大语言模型。 在完成 安装 后,您可以通过以下指令进入 WebUI:
|
||||
|
||||
WebUI 主要分为四个界面:训练、评估与预测、对话、导出。
|
||||
|
||||
随后,您可以点击 开始 按钮开始训练模型。
|
||||
|
||||
关于断点重连:适配器断点保存于 output_dir 目录下,请指定 适配器路径 以加载断点继续训练。
|
||||
|
||||
如果您需要使用自定义数据集,请在 data/data_info.json 中添加自定义数据集描述并确保 数据集格式 正确,否则可能会导致训练失败。
|
||||
|
||||
模型训练完毕后,您可以通过在评估与预测界面通过指定 模型 及 适配器 的路径在指定数据集上进行评估。
|
||||
|
||||
您也可以通过在对话界面指定 模型、 适配器 及 推理引擎 后输入对话内容与模型进行对话观察效果。
|
||||
|
||||
如果您对模型效果满意并需要导出模型,您可以在导出界面通过指定 模型、 适配器、 分块大小、 导出量化等级及校准数据集、 导出设备、 导出目录 等参数后点击 导出 按钮导出模型。
|
||||
|
||||
---
|
||||
|
||||
## Merge¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/merge_lora.html
|
||||
|
||||
**Contents:**
|
||||
- Merge¶
|
||||
- 合并¶
|
||||
- 量化¶
|
||||
|
||||
当我们基于预训练模型训练好 LoRA 适配器后,我们不希望在每次推理的时候分别加载预训练模型和 LoRA 适配器,因此我们需要将预训练模型和 LoRA 适配器合并导出成一个模型,并根据需要选择是否量化。根据是否量化以及量化算法的不同,导出的配置文件也有所区别。
|
||||
|
||||
您可以通过 llamafactory-cli export merge_config.yaml 指令来合并模型。其中 merge_config.yaml 需要您根据不同情况进行配置。
|
||||
|
||||
examples/merge_lora/llama3_lora_sft.yaml 提供了合并时的配置示例。
|
||||
|
||||
模型 model_name_or_path 需要存在且与 template 相对应。 adapter_name_or_path 需要与微调中的适配器输出路径 output_dir 相对应。
|
||||
|
||||
合并 LoRA 适配器时,不要使用量化模型或指定量化位数。您可以使用本地或下载的未量化的预训练模型进行合并。
|
||||
|
||||
在完成模型合并并获得完整模型后,为了优化部署效果,人们通常会基于显存占用、使用成本和推理速度等因素,选择通过量化技术对模型进行压缩,从而实现更高效的部署。
|
||||
|
||||
量化(Quantization)通过数据精度压缩有效地减少了显存使用并加速推理。LLaMA-Factory 支持多种量化方法,包括:
|
||||
|
||||
GPTQ 等后训练量化方法(Post Training Quantization)是一种在训练后对预训练模型进行量化的方法。我们通过量化技术将高精度表示的预训练模型转换为低精度的模型,从而在避免过多损失模型性能的情况下减少显存占用并加速推理,我们希望低精度数据类型在有限的表示范围内尽可能地接近高精度数据类型的表示,因此我们需要指定量化位数 export_quantization_bit 以及校准数据集 export_quantization_dataset。
|
||||
|
||||
model_name_or_path: 预训练模型的名称或路径
|
||||
|
||||
export_quantization_bit: 量化位数
|
||||
|
||||
export_quantization_dataset: 量化校准数据集
|
||||
|
||||
export_size: 最大导出模型文件大小
|
||||
|
||||
export_legacy_format: 是否使用旧格式导出
|
||||
|
||||
QLoRA 是一种在 4-bit 量化模型基础上使用 LoRA 方法进行训练的技术。它在极大地保持了模型性能的同时大幅减少了显存占用和推理时间。
|
||||
|
||||
不要使用量化模型或设置量化位数 quantization_bit
|
||||
|
||||
---
|
||||
|
||||
## Inference¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/inference.html
|
||||
|
||||
**Contents:**
|
||||
- Inference¶
|
||||
- 原始模型推理配置¶
|
||||
- 微调模型推理配置¶
|
||||
- 多模态模型¶
|
||||
- 批量推理¶
|
||||
- 数据集¶
|
||||
- api¶
|
||||
|
||||
LLaMA-Factory 支持多种推理方式。
|
||||
|
||||
您可以使用 llamafactory-cli chat inference_config.yaml 或 llamafactory-cli webchat inference_config.yaml 进行推理与模型对话。对话时配置文件只需指定原始模型 model_name_or_path 和 template ,并根据是否是微调模型指定 adapter_name_or_path 和 finetuning_type。
|
||||
|
||||
如果您希望向模型输入大量数据集并保存推理结果,您可以启动 vllm 推理引擎对大量数据集进行快速的批量推理。您也可以通过 部署 api 服务的形式通过 api 调用来进行批量推理。
|
||||
|
||||
默认情况下,模型推理将使用 Huggingface 引擎。 您也可以指定 infer_backend: vllm 以使用 vllm 推理引擎以获得更快的推理速度。
|
||||
|
||||
使用任何方式推理时,模型 model_name_or_path 需要存在且与 template 相对应。
|
||||
|
||||
对于原始模型推理, inference_config.yaml 中 只需指定原始模型 model_name_or_path 和 template 即可。
|
||||
|
||||
对于微调模型推理,除原始模型和模板外,还需要指定适配器路径 adapter_name_or_path 和微调类型 finetuning_type。
|
||||
|
||||
对于多模态模型,您可以运行以下指令进行推理。
|
||||
|
||||
examples/inference/llava1_5.yaml 的配置示例如下:
|
||||
|
||||
您可以通过以下指令启动 vllm 推理引擎并使用数据集进行批量推理:
|
||||
|
||||
如果您需要使用 api 进行批量推理,您只需指定模型、适配器(可选)、模板、微调方式等信息。
|
||||
|
||||
下面是一个启动并调用 api 服务的示例:
|
||||
|
||||
您可以使用 API_PORT=8000 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml 启动 api 服务并运行以下示例程序进行调用:
|
||||
|
||||
---
|
||||
|
||||
## Eval¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/eval.html
|
||||
|
||||
**Contents:**
|
||||
- Eval¶
|
||||
- 通用能力评估¶
|
||||
- NLG 评估¶
|
||||
- 评估相关参数¶
|
||||
|
||||
在完成模型训练后,您可以通过 llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml 来评估模型效果。
|
||||
|
||||
配置示例文件 examples/train_lora/llama3_lora_eval.yaml 具体如下:
|
||||
|
||||
此外,您还可以通过 llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml 来获得模型的 BLEU 和 ROUGE 分数以评价模型生成质量。
|
||||
|
||||
配置示例文件 examples/extras/nlg_eval/llama3_lora_predict.yaml 具体如下:
|
||||
|
||||
同样,您也通过在指令 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo 中指定模型、数据集以使用 vllm 推理框架以取得更快的推理速度。
|
||||
|
||||
评估任务的名称,可选项有 mmlu_test, ceval_validation, cmmlu_test
|
||||
|
||||
包含评估数据集的文件夹路径,默认值为 evaluation。
|
||||
|
||||
用于数据加载器的随机种子,默认值为 42。
|
||||
|
||||
评估使用的语言,可选值为 en、 zh。默认值为 en。
|
||||
|
||||
few-shot 的示例数量,默认值为 5。
|
||||
|
||||
保存评估结果的路径,默认值为 None。 如果该路径已经存在则会抛出错误。
|
||||
|
||||
评估数据集的下载模式,默认值为 DownloadMode.REUSE_DATASET_IF_EXISTS。如果数据集已经存在则重复使用,否则则下载。
|
||||
|
||||
---
|
||||
|
||||
## Data Preparation¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/data_preparation.html
|
||||
|
||||
**Contents:**
|
||||
- Data Preparation¶
|
||||
- Alpaca¶
|
||||
- 指令监督微调数据集¶
|
||||
- 预训练数据集¶
|
||||
- 偏好数据集¶
|
||||
- KTO 数据集¶
|
||||
- 多模态数据集¶
|
||||
- 图像数据集¶
|
||||
- 视频数据集¶
|
||||
- 音频数据集¶
|
||||
|
||||
dataset_info.json 包含了所有经过预处理的 本地数据集 以及 在线数据集。如果您希望使用自定义数据集,请 务必 在 dataset_info.json 文件中添加对数据集及其内容的定义。
|
||||
|
||||
目前我们支持 Alpaca 格式和 ShareGPT 格式的数据集。
|
||||
|
||||
指令监督微调(Instruct Tuning)通过让模型学习详细的指令以及对应的回答来优化模型在特定指令下的表现。
|
||||
|
||||
instruction 列对应的内容为人类指令, input 列对应的内容为人类输入, output 列对应的内容为模型回答。下面是一个例子
|
||||
|
||||
在进行指令监督微调时, instruction 列对应的内容会与 input 列对应的内容拼接后作为最终的人类输入,即人类输入为 instruction\ninput。而 output 列对应的内容为模型回答。 在上面的例子中,人类的最终输入是:
|
||||
|
||||
如果指定, system 列对应的内容将被作为系统提示词。
|
||||
|
||||
history 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容也会被用于模型学习。
|
||||
|
||||
下面提供一个 alpaca 格式 多轮 对话的例子,对于单轮对话只需省略 history 列即可。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
大语言模型通过学习未被标记的文本进行预训练,从而学习语言的表征。通常,预训练数据集从互联网上获得,因为互联网上提供了大量的不同领域的文本信息,有助于提升模型的泛化能力。 预训练数据集文本描述格式如下:
|
||||
|
||||
在预训练时,只有 text 列中的 内容 (即document)会用于模型学习。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。对于系统指令和人类输入,偏好数据集给出了一个更优的回答和一个更差的回答。
|
||||
|
||||
一些研究 表明通过让模型学习“什么更好”可以使得模型更加迎合人类的需求。 甚至可以使得参数相对较少的模型的表现优于参数更多的模型。
|
||||
|
||||
偏好数据集需要在 chosen 列中提供更优的回答,并在 rejected 列中提供更差的回答,在一轮问答中其格式如下:
|
||||
|
||||
对于上述格式的数据,dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
KTO数据集与偏好数据集类似,但不同于给出一个更优的回答和一个更差的回答,KTO数据集对每一轮问答只给出一个 true/false 的 label。 除了 instruction 以及 input 组成的人类最终输入和模型回答 output ,KTO 数据集还需要额外添加一个 kto_tag 列(true/false)来表示人类的反馈。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
目前我们支持 多模态图像数据集、 视频数据集 以及 音频数据集 的输入。
|
||||
|
||||
多模态图像数据集需要额外添加一个 images 列,包含输入图像的路径。 注意图片的数量必须与文本中所有 <image> 标记的数量严格一致。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
多模态视频数据集需要额外添加一个 videos 列,包含输入视频的路径。 注意视频的数量必须与文本中所有 <video> 标记的数量严格一致。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
多模态音频数据集需要额外添加一个 audio 列,包含输入图像的路径。 注意音频的数量必须与文本中所有 <audio> 标记的数量严格一致。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
ShareGPT 格式中的 KTO数据集(样例)和多模态数据集(样例) 与 Alpaca 格式的类似。
|
||||
|
||||
预训练数据集不支持 ShareGPT 格式。
|
||||
|
||||
相比 alpaca 格式的数据集, sharegpt 格式支持 更多 的角色种类,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 conversations 列中。 下面是 sharegpt 格式的一个例子:
|
||||
|
||||
注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
Sharegpt 格式的偏好数据集同样需要在 chosen 列中提供更优的消息,并在 rejected 列中提供更差的消息。 下面是一个例子:
|
||||
|
||||
对于上述格式的数据,dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
||||
|
||||
对于上述格式的数据, dataset_info.json 中的 数据集描述 应为:
|
||||
|
||||
---
|
||||
|
||||
## Supervised Fine-tuning¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/getting_started/sft.html
|
||||
|
||||
**Contents:**
|
||||
- Supervised Fine-tuning¶
|
||||
- 命令行¶
|
||||
|
||||
您可以使用以下命令使用 examples/train_lora/llama3_lora_sft.yaml 中的参数进行微调:
|
||||
|
||||
也可以通过追加参数更新 yaml 文件中的参数:
|
||||
|
||||
LLaMA-Factory 默认使用所有可见的计算设备。根据需求可通过 CUDA_VISIBLE_DEVICES 或 ASCEND_RT_VISIBLE_DEVICES 指定计算设备。
|
||||
|
||||
examples/train_lora/llama3_lora_sft.yaml 提供了微调时的配置示例。该配置指定了模型参数、微调方法参数、数据集参数以及评估参数等。您需要根据自身需求自行配置。
|
||||
|
||||
模型 model_name_or_path 、数据集 dataset 需要存在且与 template 相对应。
|
||||
|
||||
训练阶段,可选: rm(reward modeling), pt(pretrain), sft(Supervised Fine-Tuning), PPO, DPO, KTO, ORPO
|
||||
|
||||
微调方式。可选: freeze, lora, full
|
||||
|
||||
采取LoRA方法的目标模块,默认值为 all。
|
||||
|
||||
数据集模板,请保证数据集模板与模型相对应。
|
||||
|
||||
per_device_train_batch_size
|
||||
|
||||
gradient_accumulation_steps
|
||||
|
||||
学习率曲线,可选 linear, cosine, polynomial, constant 等。
|
||||
|
||||
---
|
||||
@@ -0,0 +1,19 @@
|
||||
# Llama-Factory Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Images
|
||||
**File:** `_images.md`
|
||||
**Pages:** 3
|
||||
|
||||
### Advanced
|
||||
**File:** `advanced.md`
|
||||
**Pages:** 14
|
||||
|
||||
### Getting Started
|
||||
**File:** `getting_started.md`
|
||||
**Pages:** 7
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 1
|
||||
@@ -0,0 +1,31 @@
|
||||
# Llama-Factory - Other
|
||||
|
||||
**Pages:** 1
|
||||
|
||||
---
|
||||
|
||||
## Welcome to LLaMA Factory!¶
|
||||
|
||||
**URL:** https://llamafactory.readthedocs.io/en/latest/
|
||||
|
||||
**Contents:**
|
||||
- Welcome to LLaMA Factory!¶
|
||||
- Documentation¶
|
||||
|
||||
LLaMA Factory is an easy-to-use and efficient platform for training and fine-tuning large language models. With LLaMA Factory, you can fine-tune hundreds of pre-trained models locally without writing any code. Framework features include:
|
||||
|
||||
Models: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
|
||||
Trainers: (incremental) pre-training, (multimodal) instruction supervision fine-tuning, reward model training, PPO training, DPO training, KTO training, ORPO training, etc.
|
||||
|
||||
Computation Precision: 16-bit full-parameter fine-tuning, frozen fine-tuning, LoRA fine-tuning, and 2/3/4/5/6/8-bit QLoRA fine-tuning based on AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||
|
||||
Optimization Algorithms: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, and PiSSA.
|
||||
|
||||
Acceleration Operators: FlashAttention-2 and Unsloth.
|
||||
|
||||
Inference Engines: Transformers and vLLM.
|
||||
|
||||
Experiment Monitors: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab etc.
|
||||
|
||||
---
|
||||
@@ -0,0 +1,431 @@
|
||||
---
|
||||
name: peft-fine-tuning
|
||||
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with transformers ecosystem.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]
|
||||
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
|
||||
---
|
||||
|
||||
# PEFT (Parameter-Efficient Fine-Tuning)
|
||||
|
||||
Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.
|
||||
|
||||
## When to use PEFT
|
||||
|
||||
**Use PEFT/LoRA when:**
|
||||
- Fine-tuning 7B-70B models on consumer GPUs (RTX 4090, A100)
|
||||
- Need to train <1% parameters (6MB adapters vs 14GB full model)
|
||||
- Want fast iteration with multiple task-specific adapters
|
||||
- Deploying multiple fine-tuned variants from one base model
|
||||
|
||||
**Use QLoRA (PEFT + quantization) when:**
|
||||
- Fine-tuning 70B models on single 24GB GPU
|
||||
- Memory is the primary constraint
|
||||
- Can accept ~5% quality trade-off vs full fine-tuning
|
||||
|
||||
**Use full fine-tuning instead when:**
|
||||
- Training small models (<1B parameters)
|
||||
- Need maximum quality and have compute budget
|
||||
- Significant domain shift requires updating all weights
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
pip install peft
|
||||
|
||||
# With quantization support (recommended)
|
||||
pip install peft bitsandbytes
|
||||
|
||||
# Full stack
|
||||
pip install peft transformers accelerate bitsandbytes datasets
|
||||
```
|
||||
|
||||
### LoRA fine-tuning (standard)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load base model
|
||||
model_name = "meta-llama/Llama-3.1-8B"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# LoRA configuration
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=16, # Rank (8-64, higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
lora_dropout=0.05, # Dropout for regularization
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention layers
|
||||
bias="none" # Don't train biases
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%
|
||||
|
||||
# Prepare dataset
|
||||
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
|
||||
|
||||
def tokenize(example):
|
||||
text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
|
||||
return tokenizer(text, truncation=True, max_length=512, padding="max_length")
|
||||
|
||||
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)
|
||||
|
||||
# Training
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./lora-llama",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized,
|
||||
data_collator=lambda data: {"input_ids": torch.stack([f["input_ids"] for f in data]),
|
||||
"attention_mask": torch.stack([f["attention_mask"] for f in data]),
|
||||
"labels": torch.stack([f["input_ids"] for f in data])}
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save adapter only (6MB vs 16GB)
|
||||
model.save_pretrained("./lora-llama-adapter")
|
||||
```
|
||||
|
||||
### QLoRA fine-tuning (memory-efficient)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
|
||||
|
||||
# 4-bit quantization config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4", # NormalFloat4 (best for LLMs)
|
||||
bnb_4bit_compute_dtype="bfloat16", # Compute in bf16
|
||||
bnb_4bit_use_double_quant=True # Nested quantization
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-70B",
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Prepare for training (enables gradient checkpointing)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# LoRA config for QLoRA
|
||||
lora_config = LoraConfig(
|
||||
r=64, # Higher rank for 70B
|
||||
lora_alpha=128,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
# 70B model now fits on single 24GB GPU!
|
||||
```
|
||||
|
||||
## LoRA parameter selection
|
||||
|
||||
### Rank (r) - capacity vs efficiency
|
||||
|
||||
| Rank | Trainable Params | Memory | Quality | Use Case |
|
||||
|------|-----------------|--------|---------|----------|
|
||||
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
|
||||
| **8** | ~7M | Low | Good | **Recommended starting point** |
|
||||
| **16** | ~14M | Medium | Better | **General fine-tuning** |
|
||||
| 32 | ~27M | Higher | High | Complex tasks |
|
||||
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |
|
||||
|
||||
### Alpha (lora_alpha) - scaling factor
|
||||
|
||||
```python
|
||||
# Rule of thumb: alpha = 2 * rank
|
||||
LoraConfig(r=16, lora_alpha=32) # Standard
|
||||
LoraConfig(r=16, lora_alpha=16) # Conservative (lower learning rate effect)
|
||||
LoraConfig(r=16, lora_alpha=64) # Aggressive (higher learning rate effect)
|
||||
```
|
||||
|
||||
### Target modules by architecture
|
||||
|
||||
```python
|
||||
# Llama / Mistral / Qwen
|
||||
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
# GPT-2 / GPT-Neo
|
||||
target_modules = ["c_attn", "c_proj", "c_fc"]
|
||||
|
||||
# Falcon
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# BLOOM
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# Auto-detect all linear layers
|
||||
target_modules = "all-linear" # PEFT 0.6.0+
|
||||
```
|
||||
|
||||
## Loading and merging adapters
|
||||
|
||||
### Load trained adapter
|
||||
|
||||
```python
|
||||
from peft import PeftModel, AutoPeftModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Option 1: Load with PeftModel
|
||||
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")
|
||||
|
||||
# Option 2: Load directly (recommended)
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"./lora-llama-adapter",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Merge adapter into base model
|
||||
|
||||
```python
|
||||
# Merge for deployment (no adapter overhead)
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Save merged model
|
||||
merged_model.save_pretrained("./llama-merged")
|
||||
tokenizer.save_pretrained("./llama-merged")
|
||||
|
||||
# Push to Hub
|
||||
merged_model.push_to_hub("username/llama-finetuned")
|
||||
```
|
||||
|
||||
### Multi-adapter serving
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load base with first adapter
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1")
|
||||
|
||||
# Load additional adapters
|
||||
model.load_adapter("./adapter-task2", adapter_name="task2")
|
||||
model.load_adapter("./adapter-task3", adapter_name="task3")
|
||||
|
||||
# Switch between adapters at runtime
|
||||
model.set_adapter("task1") # Use task1 adapter
|
||||
output1 = model.generate(**inputs)
|
||||
|
||||
model.set_adapter("task2") # Switch to task2
|
||||
output2 = model.generate(**inputs)
|
||||
|
||||
# Disable adapters (use base model)
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
```
|
||||
|
||||
## PEFT methods comparison
|
||||
|
||||
| Method | Trainable % | Memory | Speed | Best For |
|
||||
|--------|------------|--------|-------|----------|
|
||||
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
|
||||
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
|
||||
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
|
||||
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
|
||||
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
|
||||
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
|
||||
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |
|
||||
|
||||
### IA3 (minimal parameters)
|
||||
|
||||
```python
|
||||
from peft import IA3Config
|
||||
|
||||
ia3_config = IA3Config(
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
|
||||
feedforward_modules=["down_proj"]
|
||||
)
|
||||
model = get_peft_model(model, ia3_config)
|
||||
# Trains only 0.01% of parameters!
|
||||
```
|
||||
|
||||
### Prefix Tuning
|
||||
|
||||
```python
|
||||
from peft import PrefixTuningConfig
|
||||
|
||||
prefix_config = PrefixTuningConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
num_virtual_tokens=20, # Prepended tokens
|
||||
prefix_projection=True # Use MLP projection
|
||||
)
|
||||
model = get_peft_model(model, prefix_config)
|
||||
```
|
||||
|
||||
## Integration patterns
|
||||
|
||||
### With TRL (SFTTrainer)
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=SFTConfig(output_dir="./output", max_seq_length=512),
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config, # Pass LoRA config directly
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### With Axolotl (YAML config)
|
||||
|
||||
```yaml
|
||||
# axolotl config.yaml
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
lora_target_linear: true # Target all linear layers
|
||||
```
|
||||
|
||||
### With vLLM (inference)
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Load base model with LoRA support
|
||||
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)
|
||||
|
||||
# Serve with adapter
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
|
||||
)
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Memory usage (Llama 3.1 8B)
|
||||
|
||||
| Method | GPU Memory | Trainable Params |
|
||||
|--------|-----------|------------------|
|
||||
| Full fine-tuning | 60+ GB | 8B (100%) |
|
||||
| LoRA r=16 | 18 GB | 14M (0.17%) |
|
||||
| QLoRA r=16 | 6 GB | 14M (0.17%) |
|
||||
| IA3 | 16 GB | 800K (0.01%) |
|
||||
|
||||
### Training speed (A100 80GB)
|
||||
|
||||
| Method | Tokens/sec | vs Full FT |
|
||||
|--------|-----------|------------|
|
||||
| Full FT | 2,500 | 1x |
|
||||
| LoRA | 3,200 | 1.3x |
|
||||
| QLoRA | 2,100 | 0.84x |
|
||||
|
||||
### Quality (MMLU benchmark)
|
||||
|
||||
| Model | Full FT | LoRA | QLoRA |
|
||||
|-------|---------|------|-------|
|
||||
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
|
||||
| Llama 2-13B | 54.8 | 54.2 | 53.5 |
|
||||
|
||||
## Common issues
|
||||
|
||||
### CUDA OOM during training
|
||||
|
||||
```python
|
||||
# Solution 1: Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Solution 2: Reduce batch size + increase accumulation
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16
|
||||
)
|
||||
|
||||
# Solution 3: Use QLoRA
|
||||
from transformers import BitsAndBytesConfig
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
```
|
||||
|
||||
### Adapter not applying
|
||||
|
||||
```python
|
||||
# Verify adapter is active
|
||||
print(model.active_adapters) # Should show adapter name
|
||||
|
||||
# Check trainable parameters
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
```
|
||||
|
||||
### Quality degradation
|
||||
|
||||
```python
|
||||
# Increase rank
|
||||
LoraConfig(r=32, lora_alpha=64)
|
||||
|
||||
# Target more modules
|
||||
target_modules = "all-linear"
|
||||
|
||||
# Use more training data and epochs
|
||||
TrainingArguments(num_train_epochs=5)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=1e-4)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with r=8-16**, increase if quality insufficient
|
||||
2. **Use alpha = 2 * rank** as starting point
|
||||
3. **Target attention + MLP layers** for best quality/efficiency
|
||||
4. **Enable gradient checkpointing** for memory savings
|
||||
5. **Save adapters frequently** (small files, easy rollback)
|
||||
6. **Evaluate on held-out data** before merging
|
||||
7. **Use QLoRA for 70B+ models** on consumer hardware
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/huggingface/peft
|
||||
- **Docs**: https://huggingface.co/docs/peft
|
||||
- **LoRA Paper**: arXiv:2106.09685
|
||||
- **QLoRA Paper**: arXiv:2305.14314
|
||||
- **Models**: https://huggingface.co/models?library=peft
|
||||
@@ -0,0 +1,514 @@
|
||||
# PEFT Advanced Usage Guide
|
||||
|
||||
## Advanced LoRA Variants
|
||||
|
||||
### DoRA (Weight-Decomposed Low-Rank Adaptation)
|
||||
|
||||
DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
dora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
use_dora=True, # Enable DoRA
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, dora_config)
|
||||
```
|
||||
|
||||
**When to use DoRA**:
|
||||
- Consistently outperforms LoRA on instruction-following tasks
|
||||
- Slightly higher memory (~10%) due to magnitude vectors
|
||||
- Best for quality-critical fine-tuning
|
||||
|
||||
### AdaLoRA (Adaptive Rank)
|
||||
|
||||
Automatically adjusts rank per layer based on importance:
|
||||
|
||||
```python
|
||||
from peft import AdaLoraConfig
|
||||
|
||||
adalora_config = AdaLoraConfig(
|
||||
init_r=64, # Initial rank
|
||||
target_r=16, # Target average rank
|
||||
tinit=200, # Warmup steps
|
||||
tfinal=1000, # Final pruning step
|
||||
deltaT=10, # Rank update frequency
|
||||
beta1=0.85,
|
||||
beta2=0.85,
|
||||
orth_reg_weight=0.5, # Orthogonality regularization
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Allocates more rank to important layers
|
||||
- Can reduce total parameters while maintaining quality
|
||||
- Good for exploring optimal rank distribution
|
||||
|
||||
### LoRA+ (Asymmetric Learning Rates)
|
||||
|
||||
Different learning rates for A and B matrices:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# LoRA+ uses higher LR for B matrix
|
||||
lora_plus_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
use_rslora=True, # Rank-stabilized LoRA (related technique)
|
||||
)
|
||||
|
||||
# Manual implementation of LoRA+
|
||||
from torch.optim import AdamW
|
||||
|
||||
# Group parameters
|
||||
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
|
||||
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]
|
||||
|
||||
optimizer = AdamW([
|
||||
{"params": lora_A_params, "lr": 1e-4},
|
||||
{"params": lora_B_params, "lr": 1e-3}, # 10x higher for B
|
||||
])
|
||||
```
|
||||
|
||||
### rsLoRA (Rank-Stabilized LoRA)
|
||||
|
||||
Scales LoRA outputs to stabilize training with different ranks:
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=64,
|
||||
use_rslora=True, # Enables rank-stabilized scaling
|
||||
target_modules="all-linear"
|
||||
)
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- When experimenting with different ranks
|
||||
- Helps maintain consistent behavior across rank values
|
||||
- Recommended for r > 32
|
||||
|
||||
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
|
||||
|
||||
Initializes LoRA weights to compensate for quantization error:
|
||||
|
||||
```python
|
||||
from peft import LoftQConfig, LoraConfig, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
# LoftQ configuration
|
||||
loftq_config = LoftQConfig(
|
||||
loftq_bits=4, # Quantization bits
|
||||
loftq_iter=5, # Alternating optimization iterations
|
||||
)
|
||||
|
||||
# LoRA config with LoftQ initialization
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
**Benefits over standard QLoRA**:
|
||||
- Better initial quality after quantization
|
||||
- Faster convergence
|
||||
- ~1-2% better final accuracy on benchmarks
|
||||
|
||||
## Custom Module Targeting
|
||||
|
||||
### Target specific layers
|
||||
|
||||
```python
|
||||
# Target only first and last transformer layers
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
"model.layers.31.self_attn.q_proj",
|
||||
"model.layers.31.self_attn.v_proj"],
|
||||
layers_to_transform=[0, 31] # Alternative approach
|
||||
)
|
||||
```
|
||||
|
||||
### Layer pattern matching
|
||||
|
||||
```python
|
||||
# Target layers 0-10 only
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
layers_to_transform=list(range(11)), # Layers 0-10
|
||||
layers_pattern="model.layers"
|
||||
)
|
||||
```
|
||||
|
||||
### Exclude specific layers
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head"], # Train these fully (not LoRA)
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding and LM Head Training
|
||||
|
||||
### Train embeddings with LoRA
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# Include embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "embed_tokens"], # Include embeddings
|
||||
modules_to_save=["lm_head"], # Train lm_head fully
|
||||
)
|
||||
```
|
||||
|
||||
### Extending vocabulary with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import get_peft_model, LoraConfig
|
||||
|
||||
# Add new tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
|
||||
tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# Resize model embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Configure LoRA to train new embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["embed_tokens", "lm_head"], # Train these fully
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Multi-Adapter Patterns
|
||||
|
||||
### Adapter composition
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load model with multiple adapters
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
|
||||
model.load_adapter("./style-adapter", adapter_name="style")
|
||||
model.load_adapter("./task-adapter", adapter_name="task")
|
||||
|
||||
# Combine adapters (weighted sum)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["style", "task"],
|
||||
weights=[0.7, 0.3],
|
||||
adapter_name="combined",
|
||||
combination_type="linear" # or "cat", "svd"
|
||||
)
|
||||
|
||||
model.set_adapter("combined")
|
||||
```
|
||||
|
||||
### Adapter stacking
|
||||
|
||||
```python
|
||||
# Stack adapters (apply sequentially)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["base", "domain", "task"],
|
||||
weights=[1.0, 1.0, 1.0],
|
||||
adapter_name="stacked",
|
||||
combination_type="cat" # Concatenate adapter outputs
|
||||
)
|
||||
```
|
||||
|
||||
### Dynamic adapter switching
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class MultiAdapterModel:
|
||||
def __init__(self, base_model_path, adapter_paths):
|
||||
self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_paths[0])
|
||||
for name, path in adapter_paths[1:].items():
|
||||
self.model.load_adapter(path, adapter_name=name)
|
||||
|
||||
def generate(self, prompt, adapter_name="default"):
|
||||
self.model.set_adapter(adapter_name)
|
||||
return self.model.generate(**self.tokenize(prompt))
|
||||
|
||||
def generate_ensemble(self, prompt, adapters, weights):
|
||||
"""Generate with weighted adapter ensemble"""
|
||||
outputs = []
|
||||
for adapter, weight in zip(adapters, weights):
|
||||
self.model.set_adapter(adapter)
|
||||
logits = self.model(**self.tokenize(prompt)).logits
|
||||
outputs.append(weight * logits)
|
||||
return torch.stack(outputs).sum(dim=0)
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Gradient checkpointing with LoRA
|
||||
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model = prepare_model_for_kbit_training(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
```
|
||||
|
||||
### CPU offloading for training
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="bf16",
|
||||
gradient_accumulation_steps=8,
|
||||
cpu_offload=True # Offload optimizer states to CPU
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
### Memory-efficient attention with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Combine Flash Attention 2 with LoRA
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Inference Optimization
|
||||
|
||||
### Merge for deployment
|
||||
|
||||
```python
|
||||
# Merge adapter weights into base model
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Quantize merged model for inference
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
```
|
||||
|
||||
### Export to different formats
|
||||
|
||||
```python
|
||||
# Export to GGUF (llama.cpp)
|
||||
# First merge, then convert
|
||||
merged_model.save_pretrained("./merged-model")
|
||||
|
||||
# Use llama.cpp converter
|
||||
# python convert-hf-to-gguf.py ./merged-model --outfile model.gguf
|
||||
|
||||
# Export to ONNX
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
|
||||
ort_model = ORTModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
export=True
|
||||
)
|
||||
ort_model.save_pretrained("./onnx-model")
|
||||
```
|
||||
|
||||
### Batch adapter inference
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Initialize with LoRA support
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B",
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4 # Max concurrent adapters
|
||||
)
|
||||
|
||||
# Batch with different adapters
|
||||
requests = [
|
||||
("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
|
||||
("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
]
|
||||
|
||||
outputs = llm.generate(
|
||||
[r[0] for r in requests],
|
||||
lora_request=[r[1] for r in requests]
|
||||
)
|
||||
```
|
||||
|
||||
## Training Recipes
|
||||
|
||||
### Instruction tuning recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./output",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
bf16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=100,
|
||||
eval_strategy="steps",
|
||||
eval_steps=100,
|
||||
)
|
||||
```
|
||||
|
||||
### Code generation recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=32, # Higher rank for code
|
||||
lora_alpha=64,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
learning_rate=1e-4, # Lower LR for code
|
||||
num_train_epochs=2,
|
||||
max_seq_length=2048, # Longer sequences
|
||||
)
|
||||
```
|
||||
|
||||
### Conversational/Chat recipe
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=16, # alpha = r for chat
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear"
|
||||
)
|
||||
|
||||
# Use chat template
|
||||
def format_chat(example):
|
||||
messages = [
|
||||
{"role": "user", "content": example["instruction"]},
|
||||
{"role": "assistant", "content": example["response"]}
|
||||
]
|
||||
return tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset.map(format_chat),
|
||||
max_seq_length=1024,
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging and Validation
|
||||
|
||||
### Verify adapter application
|
||||
|
||||
```python
|
||||
# Check which modules have LoRA
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "lora_A"):
|
||||
print(f"LoRA applied to: {name}")
|
||||
|
||||
# Print detailed config
|
||||
print(model.peft_config)
|
||||
|
||||
# Check adapter state
|
||||
print(f"Active adapters: {model.active_adapters}")
|
||||
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
```
|
||||
|
||||
### Compare with base model
|
||||
|
||||
```python
|
||||
# Generate with adapter
|
||||
model.set_adapter("default")
|
||||
adapter_output = model.generate(**inputs)
|
||||
|
||||
# Generate without adapter
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
|
||||
print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
|
||||
print(f"Base: {tokenizer.decode(base_output[0])}")
|
||||
```
|
||||
|
||||
### Monitor training metrics
|
||||
|
||||
```python
|
||||
from transformers import TrainerCallback
|
||||
|
||||
class LoRACallback(TrainerCallback):
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if "loss" in logs:
|
||||
# Log adapter-specific metrics
|
||||
model = kwargs["model"]
|
||||
lora_params = sum(p.numel() for n, p in model.named_parameters()
|
||||
if "lora" in n and p.requires_grad)
|
||||
print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
|
||||
```
|
||||
@@ -0,0 +1,480 @@
|
||||
# PEFT Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### bitsandbytes CUDA Error
|
||||
|
||||
**Error**: `CUDA Setup failed despite GPU being available`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# Install matching bitsandbytes
|
||||
pip uninstall bitsandbytes
|
||||
pip install bitsandbytes --no-cache-dir
|
||||
|
||||
# Or compile from source for specific CUDA
|
||||
git clone https://github.com/TimDettmers/bitsandbytes.git
|
||||
cd bitsandbytes
|
||||
CUDA_VERSION=118 make cuda11x # Adjust for your CUDA
|
||||
pip install .
|
||||
```
|
||||
|
||||
### Triton Import Error
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'triton'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install triton (Linux only)
|
||||
pip install triton
|
||||
|
||||
# Windows: Triton not supported, use CUDA backend
|
||||
# Set environment variable to disable triton
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
```
|
||||
|
||||
### PEFT Version Conflicts
|
||||
|
||||
**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Upgrade to latest PEFT
|
||||
pip install peft>=0.13.0 --upgrade
|
||||
|
||||
# Check version
|
||||
python -c "import peft; print(peft.__version__)"
|
||||
```
|
||||
|
||||
## Training Issues
|
||||
|
||||
### CUDA Out of Memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
||||
```
|
||||
|
||||
2. **Reduce batch size**:
|
||||
```python
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16 # Maintain effective batch size
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use QLoRA**:
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
|
||||
```
|
||||
|
||||
4. **Lower LoRA rank**:
|
||||
```python
|
||||
LoraConfig(r=8) # Instead of r=16 or higher
|
||||
```
|
||||
|
||||
5. **Target fewer modules**:
|
||||
```python
|
||||
target_modules=["q_proj", "v_proj"] # Instead of all-linear
|
||||
```
|
||||
|
||||
### Loss Not Decreasing
|
||||
|
||||
**Problem**: Training loss stays flat or increases.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check learning rate**:
|
||||
```python
|
||||
# Start lower
|
||||
TrainingArguments(learning_rate=1e-4) # Not 2e-4 or higher
|
||||
```
|
||||
|
||||
2. **Verify adapter is active**:
|
||||
```python
|
||||
model.print_trainable_parameters()
|
||||
# Should show >0 trainable params
|
||||
|
||||
# Check adapter applied
|
||||
print(model.peft_config)
|
||||
```
|
||||
|
||||
3. **Check data formatting**:
|
||||
```python
|
||||
# Verify tokenization
|
||||
sample = dataset[0]
|
||||
decoded = tokenizer.decode(sample["input_ids"])
|
||||
print(decoded) # Should look correct
|
||||
```
|
||||
|
||||
4. **Increase rank**:
|
||||
```python
|
||||
LoraConfig(r=32, lora_alpha=64) # More capacity
|
||||
```
|
||||
|
||||
### NaN Loss
|
||||
|
||||
**Error**: `Loss is NaN`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use bf16 instead of fp16
|
||||
TrainingArguments(bf16=True, fp16=False)
|
||||
|
||||
# Or enable loss scaling
|
||||
TrainingArguments(fp16=True, fp16_full_eval=True)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=5e-5)
|
||||
|
||||
# Check for data issues
|
||||
for batch in dataloader:
|
||||
if torch.isnan(batch["input_ids"].float()).any():
|
||||
print("NaN in input!")
|
||||
```
|
||||
|
||||
### Adapter Not Training
|
||||
|
||||
**Problem**: `trainable params: 0` or model not updating.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Verify LoRA applied to correct modules
|
||||
for name, module in model.named_modules():
|
||||
if "lora" in name.lower():
|
||||
print(f"Found LoRA: {name}")
|
||||
|
||||
# Check target_modules match model architecture
|
||||
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
|
||||
# Check requires_grad
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
print(f"Trainable: {name}")
|
||||
```
|
||||
|
||||
## Loading Issues
|
||||
|
||||
### Adapter Loading Fails
|
||||
|
||||
**Error**: `ValueError: Can't find adapter weights`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check adapter files exist
|
||||
import os
|
||||
print(os.listdir("./adapter-path"))
|
||||
# Should contain: adapter_config.json, adapter_model.safetensors
|
||||
|
||||
# Load with correct structure
|
||||
from peft import PeftModel, PeftConfig
|
||||
|
||||
# Check config
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(config)
|
||||
|
||||
# Load base model first
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter-path")
|
||||
```
|
||||
|
||||
### Base Model Mismatch
|
||||
|
||||
**Error**: `RuntimeError: size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure base model matches adapter
|
||||
from peft import PeftConfig
|
||||
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(f"Base model: {config.base_model_name_or_path}")
|
||||
|
||||
# Load exact same base model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
```
|
||||
|
||||
### Safetensors vs PyTorch Format
|
||||
|
||||
**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Force local loading
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter-path",
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# Or specify format
|
||||
model.save_pretrained("./adapter", safe_serialization=True) # safetensors
|
||||
model.save_pretrained("./adapter", safe_serialization=False) # pytorch
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Inference much slower than expected.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Merge adapter for deployment**:
|
||||
```python
|
||||
merged_model = model.merge_and_unload()
|
||||
# No adapter overhead during inference
|
||||
```
|
||||
|
||||
2. **Use optimized inference engine**:
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model="./merged-model", dtype="half")
|
||||
```
|
||||
|
||||
3. **Enable Flash Attention**:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
```
|
||||
|
||||
### Output Quality Issues
|
||||
|
||||
**Problem**: Fine-tuned model produces worse outputs.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check evaluation without adapter**:
|
||||
```python
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
# Compare with adapter output
|
||||
```
|
||||
|
||||
2. **Lower temperature during eval**:
|
||||
```python
|
||||
model.generate(**inputs, temperature=0.1, do_sample=False)
|
||||
```
|
||||
|
||||
3. **Retrain with more data**:
|
||||
```python
|
||||
# Increase training samples
|
||||
# Use higher quality data
|
||||
# Train for more epochs
|
||||
```
|
||||
|
||||
### Wrong Adapter Active
|
||||
|
||||
**Problem**: Model using wrong adapter or no adapter.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check active adapters
|
||||
print(model.active_adapters)
|
||||
|
||||
# Explicitly set adapter
|
||||
model.set_adapter("your-adapter-name")
|
||||
|
||||
# List all adapters
|
||||
print(model.peft_config.keys())
|
||||
```
|
||||
|
||||
## QLoRA Specific Issues
|
||||
|
||||
### Quantization Errors
|
||||
|
||||
**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure compute dtype matches
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16, # Match model dtype
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
# Load with correct dtype
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA OOM
|
||||
|
||||
**Error**: OOM even with 4-bit quantization.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Enable double quantization
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True # Further memory reduction
|
||||
)
|
||||
|
||||
# Use offloading
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
max_memory={0: "20GB", "cpu": "100GB"}
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA Merge Fails
|
||||
|
||||
**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Dequantize before merging
|
||||
from peft import PeftModel
|
||||
|
||||
# Load in higher precision for merging
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model_name,
|
||||
torch_dtype=torch.float16, # Not quantized
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Load adapter
|
||||
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
|
||||
|
||||
# Now merge
|
||||
merged = model.merge_and_unload()
|
||||
```
|
||||
|
||||
## Multi-Adapter Issues
|
||||
|
||||
### Adapter Conflict
|
||||
|
||||
**Error**: `ValueError: Adapter with name 'default' already exists`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use unique names
|
||||
model.load_adapter("./adapter1", adapter_name="task1")
|
||||
model.load_adapter("./adapter2", adapter_name="task2")
|
||||
|
||||
# Or delete existing
|
||||
model.delete_adapter("default")
|
||||
```
|
||||
|
||||
### Mixed Precision Adapters
|
||||
|
||||
**Error**: Adapters trained with different dtypes.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Convert adapter precision
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter")
|
||||
model = model.to(torch.bfloat16)
|
||||
|
||||
# Or load with specific dtype
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Memory Profiling
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
def print_memory():
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / 1e9
|
||||
reserved = torch.cuda.memory_reserved() / 1e9
|
||||
print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
|
||||
|
||||
# Profile during training
|
||||
print_memory() # Before
|
||||
model.train()
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
print_memory() # After
|
||||
```
|
||||
|
||||
### Speed Profiling
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
|
||||
def benchmark_generation(model, tokenizer, prompt, n_runs=5):
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Warmup
|
||||
model.generate(**inputs, max_new_tokens=10)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.perf_counter()
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
|
||||
tokens = outputs.shape[1] - inputs.input_ids.shape[1]
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Speed: {tokens/avg_time:.2f} tokens/sec")
|
||||
|
||||
# Compare adapter vs merged
|
||||
benchmark_generation(adapter_model, tokenizer, "Hello")
|
||||
benchmark_generation(merged_model, tokenizer, "Hello")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co/
|
||||
3. **PEFT Documentation**: https://huggingface.co/docs/peft
|
||||
|
||||
### Debugging Template
|
||||
|
||||
When reporting issues, include:
|
||||
|
||||
```python
|
||||
# System info
|
||||
import peft
|
||||
import transformers
|
||||
import torch
|
||||
|
||||
print(f"PEFT: {peft.__version__}")
|
||||
print(f"Transformers: {transformers.__version__}")
|
||||
print(f"PyTorch: {torch.__version__}")
|
||||
print(f"CUDA: {torch.version.cuda}")
|
||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
|
||||
|
||||
# Config
|
||||
print(model.peft_config)
|
||||
model.print_trainable_parameters()
|
||||
```
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
name: unsloth
|
||||
description: Expert guidance for fast fine-tuning with Unsloth - 2-5x faster training, 50-80% less memory, LoRA/QLoRA optimization
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, Unsloth, Fast Training, LoRA, QLoRA, Memory-Efficient, Optimization, Llama, Mistral, Gemma, Qwen]
|
||||
dependencies: [unsloth, torch, transformers, trl, datasets, peft]
|
||||
---
|
||||
|
||||
# Unsloth Skill
|
||||
|
||||
Comprehensive assistance with unsloth development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with unsloth
|
||||
- Asking about unsloth features or APIs
|
||||
- Implementing unsloth solutions
|
||||
- Debugging unsloth code
|
||||
- Learning unsloth best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
*Quick reference patterns will be added as you use the skill.*
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **llms-txt.md** - Llms-Txt documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
<!-- Trigger re-upload 1763621536 -->
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# Unsloth Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Llms-Txt
|
||||
**File:** `llms-txt.md`
|
||||
**Pages:** 136
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,82 @@
|
||||
# Unsloth Documentation
|
||||
|
||||
## Unsloth Documentation
|
||||
|
||||
- [Unsloth Docs](/get-started/unsloth-docs.md): Train your own model with Unsloth, an open-source framework for LLM fine-tuning and reinforcement learning.
|
||||
- [Beginner? Start here!](/get-started/beginner-start-here.md)
|
||||
- [Unsloth Requirements](/get-started/beginner-start-here/unsloth-requirements.md): Here are Unsloth's requirements including system and GPU VRAM requirements.
|
||||
- [FAQ + Is Fine-tuning Right For Me?](/get-started/beginner-start-here/faq-+-is-fine-tuning-right-for-me.md): If you're stuck on if fine-tuning is right for you, see here! Learn about fine-tuning misconceptions, how it compared to RAG and more:
|
||||
- [Unsloth Notebooks](/get-started/unsloth-notebooks.md): Explore our catalog of Unsloth notebooks:
|
||||
- [All Our Models](/get-started/all-our-models.md)
|
||||
- [Install & Update](/get-started/install-and-update.md): Learn to install Unsloth locally or online.
|
||||
- [Updating](/get-started/install-and-update/updating.md): To update or use an old version of Unsloth, follow the steps below:
|
||||
- [Pip Install](/get-started/install-and-update/pip-install.md): To install Unsloth locally via Pip, follow the steps below:
|
||||
- [Docker](/get-started/install-and-update/docker.md): Install Unsloth using our official Docker container
|
||||
- [Windows Installation](/get-started/install-and-update/windows-installation.md): See how to install Unsloth on Windows with or without WSL.
|
||||
- [AMD](/get-started/install-and-update/amd.md): Fine-tune with Unsloth on AMD GPUs.
|
||||
- [Conda Install](/get-started/install-and-update/conda-install.md): To install Unsloth locally on Conda, follow the steps below:
|
||||
- [Google Colab](/get-started/install-and-update/google-colab.md): To install and run Unsloth on Google Colab, follow the steps below:
|
||||
- [Fine-tuning LLMs Guide](/get-started/fine-tuning-llms-guide.md): Learn all the basics and best practices of fine-tuning. Beginner-friendly.
|
||||
- [What Model Should I Use?](/get-started/fine-tuning-llms-guide/what-model-should-i-use.md)
|
||||
- [Datasets Guide](/get-started/fine-tuning-llms-guide/datasets-guide.md): Learn how to create & prepare a dataset for fine-tuning.
|
||||
- [LoRA Hyperparameters Guide](/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide.md): Optimal lora rank. alpha, number of epochs, batch size & gradient accumulation, QLoRA vs LoRA, target modules and more!
|
||||
- [Tutorial: How to Finetune Llama-3 and Use In Ollama](/get-started/fine-tuning-llms-guide/tutorial-how-to-finetune-llama-3-and-use-in-ollama.md): Beginner's Guide for creating a customized personal assistant (like ChatGPT) to run locally on Ollama
|
||||
- [Reinforcement Learning (RL) Guide](/get-started/reinforcement-learning-rl-guide.md): Learn all about Reinforcement Learning (RL) and how to train your own DeepSeek-R1 reasoning model with Unsloth using GRPO. A complete guide from beginner to advanced.
|
||||
- [Tutorial: Train your own Reasoning model with GRPO](/get-started/reinforcement-learning-rl-guide/tutorial-train-your-own-reasoning-model-with-grpo.md): Beginner's Guide to transforming a model like Llama 3.1 (8B) into a reasoning model by using Unsloth and GRPO.
|
||||
- [Advanced RL Documentation](/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation.md): Advanced documentation settings when using Unsloth with GRPO.
|
||||
- [Memory Efficient RL](/get-started/reinforcement-learning-rl-guide/memory-efficient-rl.md)
|
||||
- [RL Reward Hacking](/get-started/reinforcement-learning-rl-guide/rl-reward-hacking.md): Learn what is Reward Hacking in Reinforcement Learning and how to counter it.
|
||||
- [GSPO Reinforcement Learning](/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning.md): Train with GSPO (Group Sequence Policy Optimization) RL in Unsloth.
|
||||
- [Reinforcement Learning - DPO, ORPO & KTO](/get-started/reinforcement-learning-rl-guide/reinforcement-learning-dpo-orpo-and-kto.md): To use the reward modelling functions for DPO, GRPO, ORPO or KTO with Unsloth, follow the steps below:
|
||||
- [DeepSeek-OCR: How to Run & Fine-tune](/new/deepseek-ocr-how-to-run-and-fine-tune.md): Guide on how to run and fine-tune DeepSeek-OCR locally.
|
||||
- [How to Fine-tune LLMs with Unsloth & Docker](/new/how-to-fine-tune-llms-with-unsloth-and-docker.md): Learn how to fine-tune LLMs or do Reinforcement Learning (RL) with Unsloth's Docker image.
|
||||
- [Vision Reinforcement Learning (VLM RL)](/new/vision-reinforcement-learning-vlm-rl.md): Train Vision/multimodal models via GRPO and RL with Unsloth!
|
||||
- [gpt-oss Reinforcement Learning](/new/gpt-oss-reinforcement-learning.md)
|
||||
- [Tutorial: How to Train gpt-oss with RL](/new/gpt-oss-reinforcement-learning/tutorial-how-to-train-gpt-oss-with-rl.md): Learn to train OpenAI gpt-oss with GRPO to autonomously beat 2048 locally or on Colab.
|
||||
- [Unsloth Dynamic GGUFs on Aider Polyglot](/new/unsloth-dynamic-ggufs-on-aider-polyglot.md): Performance of Unsloth Dynamic GGUFs on Aider Polyglot Benchmarks
|
||||
- [Qwen3-VL: How to Run & Fine-tune](/models/qwen3-vl-how-to-run-and-fine-tune.md): Learn to fine-tune and run Qwen3-VL locally with Unsloth.
|
||||
- [gpt-oss: How to Run & Fine-tune](/models/gpt-oss-how-to-run-and-fine-tune.md): Run & fine-tune OpenAI's new open-source models!
|
||||
- [Tutorial: How to Fine-tune gpt-oss](/models/gpt-oss-how-to-run-and-fine-tune/tutorial-how-to-fine-tune-gpt-oss.md): Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
|
||||
- [Long Context gpt-oss Training](/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md)
|
||||
- [GLM-4.6: How to Run Locally](/models/glm-4.6-how-to-run-locally.md): A guide on how to run Z.ai's new GLM-4.6 model on your own local device!
|
||||
- [IBM Granite 4.0](/models/ibm-granite-4.0.md): How to run IBM Granite-4.0 with Unsloth GGUFs on llama.cpp, Ollama and how to fine-tune!
|
||||
- [DeepSeek-V3.1: How to Run Locally](/models/deepseek-v3.1-how-to-run-locally.md): A guide on how to run DeepSeek-V3.1 and Terminus on your own local device!
|
||||
- [Qwen3-Coder: How to Run Locally](/models/qwen3-coder-how-to-run-locally.md): Run Qwen3-Coder-30B-A3B-Instruct and 480B-A35B locally with Unsloth Dynamic quants.
|
||||
- [Gemma 3: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune.md): How to run Gemma 3 effectively with our GGUFs on llama.cpp, Ollama, Open WebUI and how to fine-tune with Unsloth!
|
||||
- [Gemma 3n: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune.md): Run Google's new Gemma 3n locally with Dynamic GGUFs on llama.cpp, Ollama, Open WebUI and fine-tune with Unsloth!
|
||||
- [Qwen3: How to Run & Fine-tune](/models/qwen3-how-to-run-and-fine-tune.md): Learn to run & fine-tune Qwen3 locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Qwen3-2507](/models/qwen3-how-to-run-and-fine-tune/qwen3-2507.md): Run Qwen3-30B-A3B-2507 and 235B-A22B Thinking and Instruct versions locally on your device!
|
||||
- [Tutorials: How To Fine-tune & Run LLMs](/models/tutorials-how-to-fine-tune-and-run-llms.md): Learn how to run and fine-tune models for optimal performance 100% locally with Unsloth.
|
||||
- [DeepSeek-R1-0528: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-0528-how-to-run-locally.md): A guide on how to run DeepSeek-R1-0528 including Qwen3 on your own local device!
|
||||
- [Magistral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/magistral-how-to-run-and-fine-tune.md): Meet Magistral - Mistral's new reasoning models.
|
||||
- [Llama 4: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/llama-4-how-to-run-and-fine-tune.md): How to run Llama 4 locally using our dynamic GGUFs which recovers accuracy compared to standard quantization.
|
||||
- [Kimi K2: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/kimi-k2-how-to-run-locally.md): Guide on running Kimi K2 and Kimi-K2-Instruct-0905 on your own local device!
|
||||
- [Grok 2](/models/tutorials-how-to-fine-tune-and-run-llms/grok-2.md): Run xAI's Grok 2 model locally!
|
||||
- [Devstral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/devstral-how-to-run-and-fine-tune.md): Run and fine-tune Mistral Devstral 1.1, including Small-2507 and 2505.
|
||||
- [DeepSeek-V3-0324: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-v3-0324-how-to-run-locally.md): How to run DeepSeek-V3-0324 locally using our dynamic quants which recovers accuracy
|
||||
- [DeepSeek-R1: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally.md): A guide on how you can run our 1.58-bit Dynamic Quants for DeepSeek-R1 using llama.cpp.
|
||||
- [DeepSeek-R1 Dynamic 1.58-bit](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally/deepseek-r1-dynamic-1.58-bit.md): See performance comparison tables for Unsloth's Dynamic GGUF Quants vs Standard IMatrix Quants.
|
||||
- [QwQ-32B: How to Run effectively](/models/tutorials-how-to-fine-tune-and-run-llms/qwq-32b-how-to-run-effectively.md): How to run QwQ-32B effectively with our bug fixes and without endless generations + GGUFs.
|
||||
- [Phi-4 Reasoning: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/phi-4-reasoning-how-to-run-and-fine-tune.md): Learn to run & fine-tune Phi-4 reasoning models locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Running & Saving Models](/basics/running-and-saving-models.md): Learn how to save your finetuned model so you can run it in your favorite inference engine.
|
||||
- [Saving to GGUF](/basics/running-and-saving-models/saving-to-gguf.md): Saving models to 16bit for GGUF so you can use it for Ollama, Jan AI, Open WebUI and more!
|
||||
- [Saving to Ollama](/basics/running-and-saving-models/saving-to-ollama.md)
|
||||
- [Saving to vLLM for deployment](/basics/running-and-saving-models/saving-to-vllm-for-deployment.md): Saving models to 16bit for vLLM deployment and serving
|
||||
- [Saving to SGLang for deployment](/basics/running-and-saving-models/saving-to-sglang-for-deployment.md): Saving models to 16bit for SGLang for deployment and serving
|
||||
- [Unsloth Inference](/basics/running-and-saving-models/unsloth-inference.md): Learn how to run your finetuned model with Unsloth's faster inference.
|
||||
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
|
||||
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
|
||||
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
|
||||
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
|
||||
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
|
||||
- [Fine-tuning LLMs with Blackwell, RTX 50 series & Unsloth](/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth.md): Learn how to fine-tune LLMs on NVIDIA's Blackwell RTX 50 series and B200 GPUs with our step-by-step guide.
|
||||
- [Multi-GPU Training with Unsloth](/basics/multi-gpu-training-with-unsloth.md): Learn how to fine-tune LLMs on multiple GPUs and parallelism with Unsloth.
|
||||
- [Finetuning from Last Checkpoint](/basics/finetuning-from-last-checkpoint.md): Checkpointing allows you to save your finetuning progress so you can pause it and then continue.
|
||||
- [Troubleshooting & FAQs](/basics/troubleshooting-and-faqs.md): Tips to solve issues, and frequently asked questions.
|
||||
- [Chat Templates](/basics/chat-templates.md): Learn the fundamentals and customization options of chat templates, including Conversational, ChatML, ShareGPT, Alpaca formats, and more!
|
||||
- [Quantization-Aware Training (QAT)](/basics/quantization-aware-training-qat.md): Quantize models to 4-bit with Unsloth and PyTorch to recover accuracy.
|
||||
- [Unsloth Environment Flags](/basics/unsloth-environment-flags.md): Advanced flags which might be useful if you see breaking finetunes, or you want to turn stuff off.
|
||||
- [Continued Pretraining](/basics/continued-pretraining.md): AKA as Continued Finetuning. Unsloth allows you to continually pretrain so a model can learn a new language.
|
||||
- [Unsloth Benchmarks](/basics/unsloth-benchmarks.md): Unsloth recorded benchmarks on NVIDIA GPUs.
|
||||
@@ -0,0 +1,436 @@
|
||||
---
|
||||
name: nnsight-remote-interpretability
|
||||
description: Provides guidance for interpreting and manipulating neural network internals using nnsight with optional NDIF remote execution. Use when needing to run interpretability experiments on massive models (70B+) without local GPU resources, or when working with any PyTorch architecture.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [nnsight, NDIF, Remote Execution, Mechanistic Interpretability, Model Internals]
|
||||
dependencies: [nnsight>=0.5.0, torch>=2.0.0]
|
||||
---
|
||||
|
||||
# nnsight: Transparent Access to Neural Network Internals
|
||||
|
||||
nnsight (/ɛn.saɪt/) enables researchers to interpret and manipulate the internals of any PyTorch model, with the unique capability of running the same code locally on small models or remotely on massive models (70B+) via NDIF.
|
||||
|
||||
**GitHub**: [ndif-team/nnsight](https://github.com/ndif-team/nnsight) (730+ stars)
|
||||
**Paper**: [NNsight and NDIF: Democratizing Access to Foundation Model Internals](https://arxiv.org/abs/2407.14561) (ICLR 2025)
|
||||
|
||||
## Key Value Proposition
|
||||
|
||||
**Write once, run anywhere**: The same interpretability code works on GPT-2 locally or Llama-3.1-405B remotely. Just toggle `remote=True`.
|
||||
|
||||
```python
|
||||
# Local execution (small model)
|
||||
with model.trace("Hello world"):
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
|
||||
# Remote execution (massive model) - same code!
|
||||
with model.trace("Hello world", remote=True):
|
||||
hidden = model.model.layers[40].output[0].save()
|
||||
```
|
||||
|
||||
## When to Use nnsight
|
||||
|
||||
**Use nnsight when you need to:**
|
||||
- Run interpretability experiments on models too large for local GPUs (70B, 405B)
|
||||
- Work with any PyTorch architecture (transformers, Mamba, custom models)
|
||||
- Perform multi-token generation interventions
|
||||
- Share activations between different prompts
|
||||
- Access full model internals without reimplementation
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You want consistent API across models → Use **TransformerLens**
|
||||
- You need declarative, shareable interventions → Use **pyvene**
|
||||
- You're training SAEs → Use **SAELens**
|
||||
- You only work with small models locally → **TransformerLens** may be simpler
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
pip install nnsight
|
||||
|
||||
# For vLLM support
|
||||
pip install "nnsight[vllm]"
|
||||
```
|
||||
|
||||
For remote NDIF execution, sign up at [login.ndif.us](https://login.ndif.us) for an API key.
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### LanguageModel Wrapper
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
# Load model (uses HuggingFace under the hood)
|
||||
model = LanguageModel("openai-community/gpt2", device_map="auto")
|
||||
|
||||
# For larger models
|
||||
model = LanguageModel("meta-llama/Llama-3.1-8B", device_map="auto")
|
||||
```
|
||||
|
||||
### Tracing Context
|
||||
|
||||
The `trace` context manager enables deferred execution - operations are collected into a computation graph:
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
with model.trace("The Eiffel Tower is in") as tracer:
|
||||
# Access any module's output
|
||||
hidden_states = model.transformer.h[5].output[0].save()
|
||||
|
||||
# Access attention patterns
|
||||
attn = model.transformer.h[5].attn.attn_dropout.input[0][0].save()
|
||||
|
||||
# Modify activations
|
||||
model.transformer.h[8].output[0][:] = 0 # Zero out layer 8
|
||||
|
||||
# Get final output
|
||||
logits = model.output.save()
|
||||
|
||||
# After context exits, access saved values
|
||||
print(hidden_states.shape) # [batch, seq, hidden]
|
||||
```
|
||||
|
||||
### Proxy Objects
|
||||
|
||||
Inside `trace`, module accesses return Proxy objects that record operations:
|
||||
|
||||
```python
|
||||
with model.trace("Hello"):
|
||||
# These are all Proxy objects - operations are deferred
|
||||
h5_out = model.transformer.h[5].output[0] # Proxy
|
||||
h5_mean = h5_out.mean(dim=-1) # Proxy
|
||||
h5_saved = h5_mean.save() # Save for later access
|
||||
```
|
||||
|
||||
## Workflow 1: Activation Analysis
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
prompt = "The capital of France is"
|
||||
|
||||
with model.trace(prompt) as tracer:
|
||||
# 1. Collect activations from multiple layers
|
||||
layer_outputs = []
|
||||
for i in range(12): # GPT-2 has 12 layers
|
||||
layer_out = model.transformer.h[i].output[0].save()
|
||||
layer_outputs.append(layer_out)
|
||||
|
||||
# 2. Get attention patterns
|
||||
attn_patterns = []
|
||||
for i in range(12):
|
||||
# Access attention weights (after softmax)
|
||||
attn = model.transformer.h[i].attn.attn_dropout.input[0][0].save()
|
||||
attn_patterns.append(attn)
|
||||
|
||||
# 3. Get final logits
|
||||
logits = model.output.save()
|
||||
|
||||
# 4. Analyze outside context
|
||||
for i, layer_out in enumerate(layer_outputs):
|
||||
print(f"Layer {i} output shape: {layer_out.shape}")
|
||||
print(f"Layer {i} norm: {layer_out.norm().item():.3f}")
|
||||
|
||||
# 5. Find top predictions
|
||||
probs = torch.softmax(logits[0, -1], dim=-1)
|
||||
top_tokens = probs.topk(5)
|
||||
for token, prob in zip(top_tokens.indices, top_tokens.values):
|
||||
print(f"{model.tokenizer.decode(token)}: {prob.item():.3f}")
|
||||
```
|
||||
|
||||
### Checklist
|
||||
- [ ] Load model with LanguageModel wrapper
|
||||
- [ ] Use trace context for operations
|
||||
- [ ] Call `.save()` on values you need after context
|
||||
- [ ] Access saved values outside context
|
||||
- [ ] Use `.shape`, `.norm()`, etc. for analysis
|
||||
|
||||
## Workflow 2: Activation Patching
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
clean_prompt = "The Eiffel Tower is in"
|
||||
corrupted_prompt = "The Colosseum is in"
|
||||
|
||||
# 1. Get clean activations
|
||||
with model.trace(clean_prompt) as tracer:
|
||||
clean_hidden = model.transformer.h[8].output[0].save()
|
||||
|
||||
# 2. Patch clean into corrupted run
|
||||
with model.trace(corrupted_prompt) as tracer:
|
||||
# Replace layer 8 output with clean activations
|
||||
model.transformer.h[8].output[0][:] = clean_hidden
|
||||
|
||||
patched_logits = model.output.save()
|
||||
|
||||
# 3. Compare predictions
|
||||
paris_token = model.tokenizer.encode(" Paris")[0]
|
||||
rome_token = model.tokenizer.encode(" Rome")[0]
|
||||
|
||||
patched_probs = torch.softmax(patched_logits[0, -1], dim=-1)
|
||||
print(f"Paris prob: {patched_probs[paris_token].item():.3f}")
|
||||
print(f"Rome prob: {patched_probs[rome_token].item():.3f}")
|
||||
```
|
||||
|
||||
### Systematic Patching Sweep
|
||||
|
||||
```python
|
||||
def patch_layer_position(layer, position, clean_cache, corrupted_prompt):
|
||||
"""Patch single layer/position from clean to corrupted."""
|
||||
with model.trace(corrupted_prompt) as tracer:
|
||||
# Get current activation
|
||||
current = model.transformer.h[layer].output[0]
|
||||
|
||||
# Patch only specific position
|
||||
current[:, position, :] = clean_cache[layer][:, position, :]
|
||||
|
||||
logits = model.output.save()
|
||||
|
||||
return logits
|
||||
|
||||
# Sweep over all layers and positions
|
||||
results = torch.zeros(12, seq_len)
|
||||
for layer in range(12):
|
||||
for pos in range(seq_len):
|
||||
logits = patch_layer_position(layer, pos, clean_hidden, corrupted)
|
||||
results[layer, pos] = compute_metric(logits)
|
||||
```
|
||||
|
||||
## Workflow 3: Remote Execution with NDIF
|
||||
|
||||
Run the same experiments on massive models without local GPUs.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
# 1. Load large model (will run remotely)
|
||||
model = LanguageModel("meta-llama/Llama-3.1-70B")
|
||||
|
||||
# 2. Same code, just add remote=True
|
||||
with model.trace("The meaning of life is", remote=True) as tracer:
|
||||
# Access internals of 70B model!
|
||||
layer_40_out = model.model.layers[40].output[0].save()
|
||||
logits = model.output.save()
|
||||
|
||||
# 3. Results returned from NDIF
|
||||
print(f"Layer 40 shape: {layer_40_out.shape}")
|
||||
|
||||
# 4. Generation with interventions
|
||||
with model.trace(remote=True) as tracer:
|
||||
with tracer.invoke("What is 2+2?"):
|
||||
# Intervene during generation
|
||||
model.model.layers[20].output[0][:, -1, :] *= 1.5
|
||||
|
||||
output = model.generate(max_new_tokens=50)
|
||||
```
|
||||
|
||||
### NDIF Setup
|
||||
|
||||
1. Sign up at [login.ndif.us](https://login.ndif.us)
|
||||
2. Get API key
|
||||
3. Set environment variable or pass to nnsight:
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["NDIF_API_KEY"] = "your_key"
|
||||
|
||||
# Or configure directly
|
||||
from nnsight import CONFIG
|
||||
CONFIG.API_KEY = "your_key"
|
||||
```
|
||||
|
||||
### Available Models on NDIF
|
||||
|
||||
- Llama-3.1-8B, 70B, 405B
|
||||
- DeepSeek-R1 models
|
||||
- Various open-weight models (check [ndif.us](https://ndif.us) for current list)
|
||||
|
||||
## Workflow 4: Cross-Prompt Activation Sharing
|
||||
|
||||
Share activations between different inputs in a single trace.
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
with model.trace() as tracer:
|
||||
# First prompt
|
||||
with tracer.invoke("The cat sat on the"):
|
||||
cat_hidden = model.transformer.h[6].output[0].save()
|
||||
|
||||
# Second prompt - inject cat's activations
|
||||
with tracer.invoke("The dog ran through the"):
|
||||
# Replace with cat's activations at layer 6
|
||||
model.transformer.h[6].output[0][:] = cat_hidden
|
||||
dog_with_cat = model.output.save()
|
||||
|
||||
# The dog prompt now has cat's internal representations
|
||||
```
|
||||
|
||||
## Workflow 5: Gradient-Based Analysis
|
||||
|
||||
Access gradients during backward pass.
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
with model.trace("The quick brown fox") as tracer:
|
||||
# Save activations and enable gradient
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
hidden.retain_grad()
|
||||
|
||||
logits = model.output
|
||||
|
||||
# Compute loss on specific token
|
||||
target_token = model.tokenizer.encode(" jumps")[0]
|
||||
loss = -logits[0, -1, target_token]
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Access gradients
|
||||
grad = hidden.grad
|
||||
print(f"Gradient shape: {grad.shape}")
|
||||
print(f"Gradient norm: {grad.norm().item():.3f}")
|
||||
```
|
||||
|
||||
**Note**: Gradient access not supported for vLLM or remote execution.
|
||||
|
||||
## Common Issues & Solutions
|
||||
|
||||
### Issue: Module path differs between models
|
||||
```python
|
||||
# GPT-2 structure
|
||||
model.transformer.h[5].output[0]
|
||||
|
||||
# LLaMA structure
|
||||
model.model.layers[5].output[0]
|
||||
|
||||
# Solution: Check model structure
|
||||
print(model._model) # See actual module names
|
||||
```
|
||||
|
||||
### Issue: Forgetting to save
|
||||
```python
|
||||
# WRONG: Value not accessible outside trace
|
||||
with model.trace("Hello"):
|
||||
hidden = model.transformer.h[5].output[0] # Not saved!
|
||||
|
||||
print(hidden) # Error or wrong value
|
||||
|
||||
# RIGHT: Call .save()
|
||||
with model.trace("Hello"):
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
|
||||
print(hidden) # Works!
|
||||
```
|
||||
|
||||
### Issue: Remote timeout
|
||||
```python
|
||||
# For long operations, increase timeout
|
||||
with model.trace("prompt", remote=True, timeout=300) as tracer:
|
||||
# Long operation...
|
||||
```
|
||||
|
||||
### Issue: Memory with many saved activations
|
||||
```python
|
||||
# Only save what you need
|
||||
with model.trace("prompt"):
|
||||
# Don't save everything
|
||||
for i in range(100):
|
||||
model.transformer.h[i].output[0].save() # Memory heavy!
|
||||
|
||||
# Better: save specific layers
|
||||
key_layers = [0, 5, 11]
|
||||
for i in key_layers:
|
||||
model.transformer.h[i].output[0].save()
|
||||
```
|
||||
|
||||
### Issue: vLLM gradient limitation
|
||||
```python
|
||||
# vLLM doesn't support gradients
|
||||
# Use standard execution for gradient analysis
|
||||
model = LanguageModel("gpt2", device_map="auto") # Not vLLM
|
||||
```
|
||||
|
||||
## Key API Reference
|
||||
|
||||
| Method/Property | Purpose |
|
||||
|-----------------|---------|
|
||||
| `model.trace(prompt, remote=False)` | Start tracing context |
|
||||
| `proxy.save()` | Save value for access after trace |
|
||||
| `proxy[:]` | Slice/index proxy (assignment patches) |
|
||||
| `tracer.invoke(prompt)` | Add prompt within trace |
|
||||
| `model.generate(...)` | Generate with interventions |
|
||||
| `model.output` | Final model output logits |
|
||||
| `model._model` | Underlying HuggingFace model |
|
||||
|
||||
## Comparison with Other Tools
|
||||
|
||||
| Feature | nnsight | TransformerLens | pyvene |
|
||||
|---------|---------|-----------------|--------|
|
||||
| Any architecture | Yes | Transformers only | Yes |
|
||||
| Remote execution | Yes (NDIF) | No | No |
|
||||
| Consistent API | No | Yes | Yes |
|
||||
| Deferred execution | Yes | No | No |
|
||||
| HuggingFace native | Yes | Reimplemented | Yes |
|
||||
| Shareable configs | No | No | Yes |
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| [references/README.md](references/README.md) | Overview and quick start guide |
|
||||
| [references/api.md](references/api.md) | Complete API reference for LanguageModel, tracing, proxy objects |
|
||||
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for local and remote interpretability |
|
||||
|
||||
## External Resources
|
||||
|
||||
### Tutorials
|
||||
- [Getting Started](https://nnsight.net/start/)
|
||||
- [Features Overview](https://nnsight.net/features/)
|
||||
- [Remote Execution](https://nnsight.net/notebooks/features/remote_execution/)
|
||||
- [Applied Tutorials](https://nnsight.net/applied_tutorials/)
|
||||
|
||||
### Official Documentation
|
||||
- [Official Docs](https://nnsight.net/documentation/)
|
||||
- [NDIF Info](https://ndif.us/)
|
||||
- [Community Forum](https://discuss.ndif.us/)
|
||||
|
||||
### Papers
|
||||
- [NNsight and NDIF Paper](https://arxiv.org/abs/2407.14561) - Fiotto-Kaufman et al. (ICLR 2025)
|
||||
|
||||
## Architecture Support
|
||||
|
||||
nnsight works with any PyTorch model:
|
||||
- **Transformers**: GPT-2, LLaMA, Mistral, etc.
|
||||
- **State Space Models**: Mamba
|
||||
- **Vision Models**: ViT, CLIP
|
||||
- **Custom architectures**: Any nn.Module
|
||||
|
||||
The key is knowing the module structure to access the right components.
|
||||
@@ -0,0 +1,78 @@
|
||||
# nnsight Reference Documentation
|
||||
|
||||
This directory contains comprehensive reference materials for nnsight.
|
||||
|
||||
## Contents
|
||||
|
||||
- [api.md](api.md) - Complete API reference for LanguageModel, tracing, and proxy objects
|
||||
- [tutorials.md](tutorials.md) - Step-by-step tutorials for local and remote interpretability
|
||||
|
||||
## Quick Links
|
||||
|
||||
- **Official Documentation**: https://nnsight.net/
|
||||
- **GitHub Repository**: https://github.com/ndif-team/nnsight
|
||||
- **NDIF (Remote Execution)**: https://ndif.us/
|
||||
- **Community Forum**: https://discuss.ndif.us/
|
||||
- **Paper**: https://arxiv.org/abs/2407.14561 (ICLR 2025)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
pip install nnsight
|
||||
|
||||
# For vLLM support
|
||||
pip install "nnsight[vllm]"
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
# Load model
|
||||
model = LanguageModel("openai-community/gpt2", device_map="auto")
|
||||
|
||||
# Trace and access internals
|
||||
with model.trace("The Eiffel Tower is in") as tracer:
|
||||
# Access layer output
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
|
||||
# Modify activations
|
||||
model.transformer.h[8].output[0][:] *= 0.5
|
||||
|
||||
# Get final output
|
||||
logits = model.output.save()
|
||||
|
||||
# Access saved values outside context
|
||||
print(hidden.shape)
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Tracing
|
||||
The `trace()` context enables deferred execution - operations are recorded and executed together.
|
||||
|
||||
### Proxy Objects
|
||||
Inside trace, module accesses return Proxies. Call `.save()` to retrieve values after execution.
|
||||
|
||||
### Remote Execution (NDIF)
|
||||
Run the same code on massive models (70B+) without local GPUs:
|
||||
|
||||
```python
|
||||
# Same code, just add remote=True
|
||||
with model.trace("Hello", remote=True):
|
||||
hidden = model.model.layers[40].output[0].save()
|
||||
```
|
||||
|
||||
## NDIF Setup
|
||||
|
||||
1. Sign up at https://login.ndif.us/
|
||||
2. Get API key
|
||||
3. Set environment variable: `export NDIF_API_KEY=your_key`
|
||||
|
||||
## Available Remote Models
|
||||
|
||||
- Llama-3.1-8B, 70B, 405B
|
||||
- DeepSeek-R1 models
|
||||
- More at https://ndif.us/
|
||||
@@ -0,0 +1,344 @@
|
||||
# nnsight API Reference
|
||||
|
||||
## LanguageModel
|
||||
|
||||
Main class for wrapping language models with intervention capabilities.
|
||||
|
||||
### Loading Models
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
# Basic loading
|
||||
model = LanguageModel("openai-community/gpt2", device_map="auto")
|
||||
|
||||
# Larger models
|
||||
model = LanguageModel("meta-llama/Llama-3.1-8B", device_map="auto")
|
||||
|
||||
# With custom tokenizer settings
|
||||
model = LanguageModel(
|
||||
"gpt2",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
### Model Attributes
|
||||
|
||||
```python
|
||||
# Access underlying HuggingFace model
|
||||
model._model
|
||||
|
||||
# Access tokenizer
|
||||
model.tokenizer
|
||||
|
||||
# Model config
|
||||
model._model.config
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tracing Context
|
||||
|
||||
The `trace()` method creates a context for deferred execution.
|
||||
|
||||
### Basic Tracing
|
||||
|
||||
```python
|
||||
with model.trace("Hello world") as tracer:
|
||||
# Operations are recorded, not executed immediately
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
logits = model.output.save()
|
||||
|
||||
# After context, operations execute and saved values are available
|
||||
print(hidden.shape)
|
||||
```
|
||||
|
||||
### Tracing Parameters
|
||||
|
||||
```python
|
||||
with model.trace(
|
||||
prompt, # Input text or tokens
|
||||
remote=False, # Use NDIF remote execution
|
||||
validate=True, # Validate tensor shapes
|
||||
scan=True, # Scan for shape info
|
||||
) as tracer:
|
||||
...
|
||||
```
|
||||
|
||||
### Remote Execution
|
||||
|
||||
```python
|
||||
# Same code works remotely
|
||||
with model.trace("Hello", remote=True) as tracer:
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Proxy Objects
|
||||
|
||||
Inside tracing context, accessing modules returns Proxy objects.
|
||||
|
||||
### Accessing Values
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# These are Proxy objects
|
||||
layer_output = model.transformer.h[5].output[0]
|
||||
attention = model.transformer.h[5].attn.output
|
||||
|
||||
# Operations create new Proxies
|
||||
mean = layer_output.mean(dim=-1)
|
||||
normed = layer_output / layer_output.norm()
|
||||
```
|
||||
|
||||
### Saving Values
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# Must call .save() to access after context
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
|
||||
# Now hidden contains actual tensor
|
||||
print(hidden.shape)
|
||||
```
|
||||
|
||||
### Modifying Values
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# In-place modification
|
||||
model.transformer.h[5].output[0][:] = 0
|
||||
|
||||
# Replace with computed value
|
||||
model.transformer.h[5].output[0][:] = some_tensor
|
||||
|
||||
# Arithmetic modification
|
||||
model.transformer.h[5].output[0][:] *= 0.5
|
||||
model.transformer.h[5].output[0][:] += steering_vector
|
||||
```
|
||||
|
||||
### Proxy Operations
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
h = model.transformer.h[5].output[0]
|
||||
|
||||
# Indexing
|
||||
first_token = h[:, 0, :]
|
||||
last_token = h[:, -1, :]
|
||||
|
||||
# PyTorch operations
|
||||
mean = h.mean(dim=-1)
|
||||
norm = h.norm()
|
||||
transposed = h.transpose(1, 2)
|
||||
|
||||
# Save results
|
||||
mean.save()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Module Access Patterns
|
||||
|
||||
### GPT-2 Structure
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# Embeddings
|
||||
embed = model.transformer.wte.output.save()
|
||||
pos_embed = model.transformer.wpe.output.save()
|
||||
|
||||
# Layer outputs
|
||||
layer_out = model.transformer.h[5].output[0].save()
|
||||
|
||||
# Attention
|
||||
attn_out = model.transformer.h[5].attn.output.save()
|
||||
|
||||
# MLP
|
||||
mlp_out = model.transformer.h[5].mlp.output.save()
|
||||
|
||||
# Final output
|
||||
logits = model.output.save()
|
||||
```
|
||||
|
||||
### LLaMA Structure
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# Embeddings
|
||||
embed = model.model.embed_tokens.output.save()
|
||||
|
||||
# Layer outputs
|
||||
layer_out = model.model.layers[10].output[0].save()
|
||||
|
||||
# Attention
|
||||
attn_out = model.model.layers[10].self_attn.output.save()
|
||||
|
||||
# MLP
|
||||
mlp_out = model.model.layers[10].mlp.output.save()
|
||||
|
||||
# Final output
|
||||
logits = model.output.save()
|
||||
```
|
||||
|
||||
### Finding Module Names
|
||||
|
||||
```python
|
||||
# Print model structure
|
||||
print(model._model)
|
||||
|
||||
# Or iterate
|
||||
for name, module in model._model.named_modules():
|
||||
print(name)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Multiple Prompts (invoke)
|
||||
|
||||
Process multiple prompts in a single trace.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
with model.trace() as tracer:
|
||||
with tracer.invoke("First prompt"):
|
||||
hidden1 = model.transformer.h[5].output[0].save()
|
||||
|
||||
with tracer.invoke("Second prompt"):
|
||||
hidden2 = model.transformer.h[5].output[0].save()
|
||||
```
|
||||
|
||||
### Cross-Prompt Intervention
|
||||
|
||||
```python
|
||||
with model.trace() as tracer:
|
||||
# Get activations from first prompt
|
||||
with tracer.invoke("The cat sat on the"):
|
||||
cat_hidden = model.transformer.h[6].output[0].save()
|
||||
|
||||
# Inject into second prompt
|
||||
with tracer.invoke("The dog ran through the"):
|
||||
model.transformer.h[6].output[0][:] = cat_hidden
|
||||
output = model.output.save()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Generation
|
||||
|
||||
Generate text with interventions.
|
||||
|
||||
### Basic Generation
|
||||
|
||||
```python
|
||||
with model.trace() as tracer:
|
||||
with tracer.invoke("Once upon a time"):
|
||||
# Intervention during generation
|
||||
model.transformer.h[5].output[0][:] *= 1.2
|
||||
|
||||
output = model.generate(max_new_tokens=50)
|
||||
|
||||
print(model.tokenizer.decode(output[0]))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Gradients
|
||||
|
||||
Access gradients for analysis (not supported with remote/vLLM).
|
||||
|
||||
```python
|
||||
with model.trace("The quick brown fox") as tracer:
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
hidden.retain_grad()
|
||||
|
||||
logits = model.output
|
||||
target_token = model.tokenizer.encode(" jumps")[0]
|
||||
loss = -logits[0, -1, target_token]
|
||||
loss.backward()
|
||||
|
||||
# Access gradient
|
||||
grad = hidden.grad
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## NDIF Remote Execution
|
||||
|
||||
### Setup
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["NDIF_API_KEY"] = "your_key"
|
||||
|
||||
# Or configure directly
|
||||
from nnsight import CONFIG
|
||||
CONFIG.set_default_api_key("your_key")
|
||||
```
|
||||
|
||||
### Using Remote
|
||||
|
||||
```python
|
||||
model = LanguageModel("meta-llama/Llama-3.1-70B")
|
||||
|
||||
with model.trace("Hello", remote=True) as tracer:
|
||||
hidden = model.model.layers[40].output[0].save()
|
||||
logits = model.output.save()
|
||||
|
||||
# Results returned from NDIF
|
||||
print(hidden.shape)
|
||||
```
|
||||
|
||||
### Sessions (Batching Requests)
|
||||
|
||||
```python
|
||||
with model.session(remote=True) as session:
|
||||
with model.trace("First prompt"):
|
||||
h1 = model.model.layers[20].output[0].save()
|
||||
|
||||
with model.trace("Second prompt"):
|
||||
h2 = model.model.layers[20].output[0].save()
|
||||
|
||||
# Both run in single NDIF request
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Utility Methods
|
||||
|
||||
### Early Stopping
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
tracer.stop() # Don't run remaining layers
|
||||
```
|
||||
|
||||
### Validation
|
||||
|
||||
```python
|
||||
# Validate shapes before execution
|
||||
with model.trace("Hello", validate=True) as tracer:
|
||||
hidden = model.transformer.h[5].output[0].save()
|
||||
```
|
||||
|
||||
### Module Access Result
|
||||
|
||||
```python
|
||||
with model.trace("Hello") as tracer:
|
||||
# Access result of a method call
|
||||
result = tracer.result
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Module Paths
|
||||
|
||||
| Model | Embeddings | Layers | Attention | MLP |
|
||||
|-------|------------|--------|-----------|-----|
|
||||
| GPT-2 | `transformer.wte` | `transformer.h[i]` | `transformer.h[i].attn` | `transformer.h[i].mlp` |
|
||||
| LLaMA | `model.embed_tokens` | `model.layers[i]` | `model.layers[i].self_attn` | `model.layers[i].mlp` |
|
||||
| Mistral | `model.embed_tokens` | `model.layers[i]` | `model.layers[i].self_attn` | `model.layers[i].mlp` |
|
||||
@@ -0,0 +1,300 @@
|
||||
# nnsight Tutorials
|
||||
|
||||
## Tutorial 1: Basic Activation Analysis
|
||||
|
||||
### Goal
|
||||
Load a model, access internal activations, and analyze them.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
# 1. Load model
|
||||
model = LanguageModel("openai-community/gpt2", device_map="auto")
|
||||
|
||||
# 2. Trace and collect activations
|
||||
prompt = "The capital of France is"
|
||||
|
||||
with model.trace(prompt) as tracer:
|
||||
# Collect from multiple layers
|
||||
activations = {}
|
||||
for i in range(12): # GPT-2 has 12 layers
|
||||
activations[i] = model.transformer.h[i].output[0].save()
|
||||
|
||||
# Get final logits
|
||||
logits = model.output.save()
|
||||
|
||||
# 3. Analyze (outside context)
|
||||
print("Layer-wise activation norms:")
|
||||
for layer, act in activations.items():
|
||||
print(f" Layer {layer}: {act.norm().item():.2f}")
|
||||
|
||||
# 4. Check predictions
|
||||
probs = torch.softmax(logits[0, -1], dim=-1)
|
||||
top_tokens = probs.topk(5)
|
||||
print("\nTop predictions:")
|
||||
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
|
||||
token_str = model.tokenizer.decode(token_id)
|
||||
print(f" {token_str!r}: {prob.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 2: Activation Patching
|
||||
|
||||
### Goal
|
||||
Patch activations from one prompt into another to test causal relationships.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
clean_prompt = "The Eiffel Tower is in the city of"
|
||||
corrupted_prompt = "The Colosseum is in the city of"
|
||||
|
||||
# 1. Get clean activations
|
||||
with model.trace(clean_prompt) as tracer:
|
||||
clean_hidden = model.transformer.h[8].output[0].save()
|
||||
clean_logits = model.output.save()
|
||||
|
||||
# 2. Define metric
|
||||
paris_token = model.tokenizer.encode(" Paris")[0]
|
||||
rome_token = model.tokenizer.encode(" Rome")[0]
|
||||
|
||||
def logit_diff(logits):
|
||||
return (logits[0, -1, paris_token] - logits[0, -1, rome_token]).item()
|
||||
|
||||
print(f"Clean logit diff: {logit_diff(clean_logits):.3f}")
|
||||
|
||||
# 3. Patch clean into corrupted
|
||||
with model.trace(corrupted_prompt) as tracer:
|
||||
# Replace layer 8 output with clean activations
|
||||
model.transformer.h[8].output[0][:] = clean_hidden
|
||||
patched_logits = model.output.save()
|
||||
|
||||
print(f"Patched logit diff: {logit_diff(patched_logits):.3f}")
|
||||
|
||||
# 4. Systematic patching sweep
|
||||
results = torch.zeros(12) # 12 layers
|
||||
|
||||
for layer in range(12):
|
||||
# Get clean activation for this layer
|
||||
with model.trace(clean_prompt) as tracer:
|
||||
clean_act = model.transformer.h[layer].output[0].save()
|
||||
|
||||
# Patch into corrupted
|
||||
with model.trace(corrupted_prompt) as tracer:
|
||||
model.transformer.h[layer].output[0][:] = clean_act
|
||||
logits = model.output.save()
|
||||
|
||||
results[layer] = logit_diff(logits)
|
||||
print(f"Layer {layer}: {results[layer]:.3f}")
|
||||
|
||||
print(f"\nMost important layer: {results.argmax().item()}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 3: Cross-Prompt Activation Sharing
|
||||
|
||||
### Goal
|
||||
Transfer activations between different prompts in a single trace.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
with model.trace() as tracer:
|
||||
# First prompt - get "cat" representations
|
||||
with tracer.invoke("The cat sat on the mat"):
|
||||
cat_hidden = model.transformer.h[6].output[0].save()
|
||||
|
||||
# Second prompt - inject "cat" into "dog"
|
||||
with tracer.invoke("The dog ran through the park"):
|
||||
# Replace with cat's activations
|
||||
model.transformer.h[6].output[0][:] = cat_hidden
|
||||
modified_logits = model.output.save()
|
||||
|
||||
# The dog prompt now has cat's internal representations
|
||||
print(f"Modified logits shape: {modified_logits.shape}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 4: Remote Execution with NDIF
|
||||
|
||||
### Goal
|
||||
Run the same interpretability code on massive models (70B+).
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import os
|
||||
|
||||
# 1. Setup API key
|
||||
os.environ["NDIF_API_KEY"] = "your_key_here"
|
||||
|
||||
# 2. Load large model (runs remotely)
|
||||
model = LanguageModel("meta-llama/Llama-3.1-70B")
|
||||
|
||||
# 3. Same code, just remote=True
|
||||
prompt = "The meaning of life is"
|
||||
|
||||
with model.trace(prompt, remote=True) as tracer:
|
||||
# Access layer 40 of 70B model!
|
||||
hidden = model.model.layers[40].output[0].save()
|
||||
logits = model.output.save()
|
||||
|
||||
# 4. Results returned from NDIF
|
||||
print(f"Hidden shape: {hidden.shape}")
|
||||
print(f"Logits shape: {logits.shape}")
|
||||
|
||||
# 5. Check predictions
|
||||
import torch
|
||||
probs = torch.softmax(logits[0, -1], dim=-1)
|
||||
top_tokens = probs.topk(5)
|
||||
print("\nTop predictions from Llama-70B:")
|
||||
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
|
||||
print(f" {model.tokenizer.decode(token_id)!r}: {prob.item():.3f}")
|
||||
```
|
||||
|
||||
### Batching with Sessions
|
||||
|
||||
```python
|
||||
# Run multiple experiments in one NDIF request
|
||||
with model.session(remote=True) as session:
|
||||
with model.trace("What is 2+2?"):
|
||||
math_hidden = model.model.layers[30].output[0].save()
|
||||
|
||||
with model.trace("The capital of France is"):
|
||||
fact_hidden = model.model.layers[30].output[0].save()
|
||||
|
||||
# Compare representations
|
||||
similarity = torch.cosine_similarity(
|
||||
math_hidden.mean(dim=1),
|
||||
fact_hidden.mean(dim=1),
|
||||
dim=-1
|
||||
)
|
||||
print(f"Similarity: {similarity.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 5: Steering with Activation Addition
|
||||
|
||||
### Goal
|
||||
Add a steering vector to change model behavior.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
# 1. Get contrasting activations
|
||||
with model.trace("I love this movie, it's wonderful") as tracer:
|
||||
positive_hidden = model.transformer.h[6].output[0].save()
|
||||
|
||||
with model.trace("I hate this movie, it's terrible") as tracer:
|
||||
negative_hidden = model.transformer.h[6].output[0].save()
|
||||
|
||||
# 2. Compute steering direction
|
||||
steering_vector = positive_hidden.mean(dim=1) - negative_hidden.mean(dim=1)
|
||||
|
||||
# 3. Generate without steering
|
||||
test_prompt = "This restaurant is"
|
||||
with model.trace(test_prompt) as tracer:
|
||||
normal_logits = model.output.save()
|
||||
|
||||
# 4. Generate with steering
|
||||
with model.trace(test_prompt) as tracer:
|
||||
# Add steering at layer 6
|
||||
model.transformer.h[6].output[0][:] += 3.0 * steering_vector
|
||||
steered_logits = model.output.save()
|
||||
|
||||
# 5. Compare predictions
|
||||
def top_prediction(logits):
|
||||
token = logits[0, -1].argmax()
|
||||
return model.tokenizer.decode(token)
|
||||
|
||||
print(f"Normal: {top_prediction(normal_logits)}")
|
||||
print(f"Steered (positive): {top_prediction(steered_logits)}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 6: Logit Lens
|
||||
|
||||
### Goal
|
||||
See what the model "believes" at each layer.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from nnsight import LanguageModel
|
||||
import torch
|
||||
|
||||
model = LanguageModel("gpt2", device_map="auto")
|
||||
|
||||
prompt = "The quick brown fox jumps over the lazy"
|
||||
|
||||
with model.trace(prompt) as tracer:
|
||||
# Collect residual stream at each layer
|
||||
residuals = []
|
||||
for i in range(12):
|
||||
resid = model.transformer.h[i].output[0].save()
|
||||
residuals.append(resid)
|
||||
|
||||
# Access model's unembedding and final layernorm
|
||||
W_U = model._model.lm_head.weight.T # [d_model, vocab]
|
||||
ln_f = model._model.transformer.ln_f
|
||||
|
||||
print("Layer-by-layer predictions for final token:")
|
||||
for i, resid in enumerate(residuals):
|
||||
# Apply final layernorm
|
||||
normed = ln_f(resid)
|
||||
|
||||
# Project to vocabulary
|
||||
layer_logits = normed @ W_U
|
||||
|
||||
# Get prediction
|
||||
probs = torch.softmax(layer_logits[0, -1], dim=-1)
|
||||
top_token = probs.argmax()
|
||||
top_prob = probs[top_token].item()
|
||||
|
||||
print(f"Layer {i}: {model.tokenizer.decode(top_token)!r} ({top_prob:.3f})")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Resources
|
||||
|
||||
### Official Resources
|
||||
- [Getting Started](https://nnsight.net/start/)
|
||||
- [Features Overview](https://nnsight.net/features/)
|
||||
- [Documentation](https://nnsight.net/documentation/)
|
||||
- [Tutorials](https://nnsight.net/tutorials/)
|
||||
|
||||
### NDIF Resources
|
||||
- [NDIF Homepage](https://ndif.us/)
|
||||
- [Available Models](https://ndif.us/models)
|
||||
- [API Key Signup](https://login.ndif.us/)
|
||||
|
||||
### Paper
|
||||
- [NNsight and NDIF](https://arxiv.org/abs/2407.14561) - ICLR 2025
|
||||
|
||||
### Community
|
||||
- [Discussion Forum](https://discuss.ndif.us/)
|
||||
- [GitHub Issues](https://github.com/ndif-team/nnsight/issues)
|
||||
@@ -0,0 +1,473 @@
|
||||
---
|
||||
name: pyvene-interventions
|
||||
description: Provides guidance for performing causal interventions on PyTorch models using pyvene's declarative intervention framework. Use when conducting causal tracing, activation patching, interchange intervention training, or testing causal hypotheses about model behavior.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Causal Intervention, pyvene, Activation Patching, Causal Tracing, Interpretability]
|
||||
dependencies: [pyvene>=0.1.8, torch>=2.0.0, transformers>=4.30.0]
|
||||
---
|
||||
|
||||
# pyvene: Causal Interventions for Neural Networks
|
||||
|
||||
pyvene is Stanford NLP's library for performing causal interventions on PyTorch models. It provides a declarative, dict-based framework for activation patching, causal tracing, and interchange intervention training - making intervention experiments reproducible and shareable.
|
||||
|
||||
**GitHub**: [stanfordnlp/pyvene](https://github.com/stanfordnlp/pyvene) (840+ stars)
|
||||
**Paper**: [pyvene: A Library for Understanding and Improving PyTorch Models via Interventions](https://aclanthology.org/2024.naacl-demo.16) (NAACL 2024)
|
||||
|
||||
## When to Use pyvene
|
||||
|
||||
**Use pyvene when you need to:**
|
||||
- Perform causal tracing (ROME-style localization)
|
||||
- Run activation patching experiments
|
||||
- Conduct interchange intervention training (IIT)
|
||||
- Test causal hypotheses about model components
|
||||
- Share/reproduce intervention experiments via HuggingFace
|
||||
- Work with any PyTorch architecture (not just transformers)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need exploratory activation analysis → Use **TransformerLens**
|
||||
- You want to train/analyze SAEs → Use **SAELens**
|
||||
- You need remote execution on massive models → Use **nnsight**
|
||||
- You want lower-level control → Use **nnsight**
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install pyvene
|
||||
```
|
||||
|
||||
Standard import:
|
||||
```python
|
||||
import pyvene as pv
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### IntervenableModel
|
||||
|
||||
The main class that wraps any PyTorch model with intervention capabilities:
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Load base model
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# Define intervention configuration
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=8,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create intervenable model
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
```
|
||||
|
||||
### Intervention Types
|
||||
|
||||
| Type | Description | Use Case |
|
||||
|------|-------------|----------|
|
||||
| `VanillaIntervention` | Swap activations between runs | Activation patching |
|
||||
| `AdditionIntervention` | Add activations to base run | Steering, ablation |
|
||||
| `SubtractionIntervention` | Subtract activations | Ablation |
|
||||
| `ZeroIntervention` | Zero out activations | Component knockout |
|
||||
| `RotatedSpaceIntervention` | DAS trainable intervention | Causal discovery |
|
||||
| `CollectIntervention` | Collect activations | Probing, analysis |
|
||||
|
||||
### Component Targets
|
||||
|
||||
```python
|
||||
# Available components to intervene on
|
||||
components = [
|
||||
"block_input", # Input to transformer block
|
||||
"block_output", # Output of transformer block
|
||||
"mlp_input", # Input to MLP
|
||||
"mlp_output", # Output of MLP
|
||||
"mlp_activation", # MLP hidden activations
|
||||
"attention_input", # Input to attention
|
||||
"attention_output", # Output of attention
|
||||
"attention_value_output", # Attention value vectors
|
||||
"query_output", # Query vectors
|
||||
"key_output", # Key vectors
|
||||
"value_output", # Value vectors
|
||||
"head_attention_value_output", # Per-head values
|
||||
]
|
||||
```
|
||||
|
||||
## Workflow 1: Causal Tracing (ROME-style)
|
||||
|
||||
Locate where factual associations are stored by corrupting inputs and restoring activations.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
|
||||
|
||||
# 1. Define clean and corrupted inputs
|
||||
clean_prompt = "The Space Needle is in downtown"
|
||||
corrupted_prompt = "The ##### ###### ## ## ########" # Noise
|
||||
|
||||
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
|
||||
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
|
||||
|
||||
# 2. Get clean activations (source)
|
||||
with torch.no_grad():
|
||||
clean_outputs = model(**clean_tokens, output_hidden_states=True)
|
||||
clean_states = clean_outputs.hidden_states
|
||||
|
||||
# 3. Define restoration intervention
|
||||
def run_causal_trace(layer, position):
|
||||
"""Restore clean activation at specific layer and position."""
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=layer,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
unit="pos",
|
||||
max_number_of_units=1,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# Run with intervention
|
||||
_, patched_outputs = intervenable(
|
||||
base=corrupted_tokens,
|
||||
sources=[clean_tokens],
|
||||
unit_locations={"sources->base": ([[[position]]], [[[position]]])},
|
||||
output_original_output=True,
|
||||
)
|
||||
|
||||
# Return probability of correct token
|
||||
probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
|
||||
seattle_token = tokenizer.encode(" Seattle")[0]
|
||||
return probs[seattle_token].item()
|
||||
|
||||
# 4. Sweep over layers and positions
|
||||
n_layers = model.config.n_layer
|
||||
seq_len = clean_tokens["input_ids"].shape[1]
|
||||
|
||||
results = torch.zeros(n_layers, seq_len)
|
||||
for layer in range(n_layers):
|
||||
for pos in range(seq_len):
|
||||
results[layer, pos] = run_causal_trace(layer, pos)
|
||||
|
||||
# 5. Visualize (layer x position heatmap)
|
||||
# High values indicate causal importance
|
||||
```
|
||||
|
||||
### Checklist
|
||||
- [ ] Prepare clean prompt with target factual association
|
||||
- [ ] Create corrupted version (noise or counterfactual)
|
||||
- [ ] Define intervention config for each (layer, position)
|
||||
- [ ] Run patching sweep
|
||||
- [ ] Identify causal hotspots in heatmap
|
||||
|
||||
## Workflow 2: Activation Patching for Circuit Analysis
|
||||
|
||||
Test which components are necessary for a specific behavior.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# IOI task setup
|
||||
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
|
||||
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"
|
||||
|
||||
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
|
||||
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
|
||||
|
||||
john_token = tokenizer.encode(" John")[0]
|
||||
mary_token = tokenizer.encode(" Mary")[0]
|
||||
|
||||
def logit_diff(logits):
|
||||
"""IO - S logit difference."""
|
||||
return logits[0, -1, john_token] - logits[0, -1, mary_token]
|
||||
|
||||
# Patch attention output at each layer
|
||||
def patch_attention(layer):
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=layer,
|
||||
component="attention_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
_, patched_outputs = intervenable(
|
||||
base=corrupted_tokens,
|
||||
sources=[clean_tokens],
|
||||
)
|
||||
|
||||
return logit_diff(patched_outputs.logits).item()
|
||||
|
||||
# Find which layers matter
|
||||
results = []
|
||||
for layer in range(model.config.n_layer):
|
||||
diff = patch_attention(layer)
|
||||
results.append(diff)
|
||||
print(f"Layer {layer}: logit diff = {diff:.3f}")
|
||||
```
|
||||
|
||||
## Workflow 3: Interchange Intervention Training (IIT)
|
||||
|
||||
Train interventions to discover causal structure.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
# 1. Define trainable intervention
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=6,
|
||||
component="block_output",
|
||||
intervention_type=pv.RotatedSpaceIntervention, # Trainable
|
||||
low_rank_dimension=64, # Learn 64-dim subspace
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 2. Set up training
|
||||
optimizer = torch.optim.Adam(
|
||||
intervenable.get_trainable_parameters(),
|
||||
lr=1e-4
|
||||
)
|
||||
|
||||
# 3. Training loop (simplified)
|
||||
for base_input, source_input, target_output in dataloader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
_, outputs = intervenable(
|
||||
base=base_input,
|
||||
sources=[source_input],
|
||||
)
|
||||
|
||||
loss = criterion(outputs.logits, target_output)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 4. Analyze learned intervention
|
||||
# The rotation matrix reveals causal subspace
|
||||
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer
|
||||
```
|
||||
|
||||
### DAS (Distributed Alignment Search)
|
||||
|
||||
```python
|
||||
# Low-rank rotation finds interpretable subspaces
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=8,
|
||||
component="block_output",
|
||||
intervention_type=pv.LowRankRotatedSpaceIntervention,
|
||||
low_rank_dimension=1, # Find 1D causal direction
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
## Workflow 4: Model Steering (Honest LLaMA)
|
||||
|
||||
Steer model behavior during generation.
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
# Load pre-trained steering intervention
|
||||
intervenable = pv.IntervenableModel.load(
|
||||
"zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Generate with steering
|
||||
prompt = "Is the earth flat?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
# Intervention applied during generation
|
||||
outputs = intervenable.generate(
|
||||
inputs,
|
||||
max_new_tokens=100,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
## Saving and Sharing Interventions
|
||||
|
||||
```python
|
||||
# Save locally
|
||||
intervenable.save("./my_intervention")
|
||||
|
||||
# Load from local
|
||||
intervenable = pv.IntervenableModel.load(
|
||||
"./my_intervention",
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Share on HuggingFace
|
||||
intervenable.save_intervention("username/my-intervention")
|
||||
|
||||
# Load from HuggingFace
|
||||
intervenable = pv.IntervenableModel.load(
|
||||
"username/my-intervention",
|
||||
model=model,
|
||||
)
|
||||
```
|
||||
|
||||
## Common Issues & Solutions
|
||||
|
||||
### Issue: Wrong intervention location
|
||||
```python
|
||||
# WRONG: Incorrect component name
|
||||
config = pv.RepresentationConfig(
|
||||
component="mlp", # Not valid!
|
||||
)
|
||||
|
||||
# RIGHT: Use exact component name
|
||||
config = pv.RepresentationConfig(
|
||||
component="mlp_output", # Valid
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Dimension mismatch
|
||||
```python
|
||||
# Ensure source and base have compatible shapes
|
||||
# For position-specific interventions:
|
||||
config = pv.RepresentationConfig(
|
||||
unit="pos",
|
||||
max_number_of_units=1, # Intervene on single position
|
||||
)
|
||||
|
||||
# Specify locations explicitly
|
||||
intervenable(
|
||||
base=base_tokens,
|
||||
sources=[source_tokens],
|
||||
unit_locations={"sources->base": ([[[5]]], [[[5]]])}, # Position 5
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Memory with large models
|
||||
```python
|
||||
# Use gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Or intervene on fewer components
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=8, # Single layer instead of all
|
||||
component="block_output",
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: LoRA integration
|
||||
```python
|
||||
# pyvene v0.1.8+ supports LoRAs as interventions
|
||||
config = pv.RepresentationConfig(
|
||||
intervention_type=pv.LoRAIntervention,
|
||||
low_rank_dimension=16,
|
||||
)
|
||||
```
|
||||
|
||||
## Key Classes Reference
|
||||
|
||||
| Class | Purpose |
|
||||
|-------|---------|
|
||||
| `IntervenableModel` | Main wrapper for interventions |
|
||||
| `IntervenableConfig` | Configuration container |
|
||||
| `RepresentationConfig` | Single intervention specification |
|
||||
| `VanillaIntervention` | Activation swapping |
|
||||
| `RotatedSpaceIntervention` | Trainable DAS intervention |
|
||||
| `CollectIntervention` | Activation collection |
|
||||
|
||||
## Supported Models
|
||||
|
||||
pyvene works with any PyTorch model. Tested on:
|
||||
- GPT-2 (all sizes)
|
||||
- LLaMA / LLaMA-2
|
||||
- Pythia
|
||||
- Mistral / Mixtral
|
||||
- OPT
|
||||
- BLIP (vision-language)
|
||||
- ESM (protein models)
|
||||
- Mamba (state space)
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| [references/README.md](references/README.md) | Overview and quick start guide |
|
||||
| [references/api.md](references/api.md) | Complete API reference for IntervenableModel, intervention types, configurations |
|
||||
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for causal tracing, activation patching, DAS |
|
||||
|
||||
## External Resources
|
||||
|
||||
### Tutorials
|
||||
- [pyvene 101](https://stanfordnlp.github.io/pyvene/tutorials/pyvene_101.html)
|
||||
- [Causal Tracing Tutorial](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/Causal_Tracing.html)
|
||||
- [IOI Circuit Replication](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/IOI_Replication.html)
|
||||
- [DAS Introduction](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/DAS_Main_Introduction.html)
|
||||
|
||||
### Papers
|
||||
- [Locating and Editing Factual Associations in GPT](https://arxiv.org/abs/2202.05262) - Meng et al. (2022)
|
||||
- [Inference-Time Intervention](https://arxiv.org/abs/2306.03341) - Li et al. (2023)
|
||||
- [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) - Wang et al. (2022)
|
||||
|
||||
### Official Documentation
|
||||
- [Official Docs](https://stanfordnlp.github.io/pyvene/)
|
||||
- [API Reference](https://stanfordnlp.github.io/pyvene/api/)
|
||||
|
||||
## Comparison with Other Tools
|
||||
|
||||
| Feature | pyvene | TransformerLens | nnsight |
|
||||
|---------|--------|-----------------|---------|
|
||||
| Declarative config | Yes | No | No |
|
||||
| HuggingFace sharing | Yes | No | No |
|
||||
| Trainable interventions | Yes | Limited | Yes |
|
||||
| Any PyTorch model | Yes | Transformers only | Yes |
|
||||
| Remote execution | No | No | Yes (NDIF) |
|
||||
@@ -0,0 +1,73 @@
|
||||
# pyvene Reference Documentation
|
||||
|
||||
This directory contains comprehensive reference materials for pyvene.
|
||||
|
||||
## Contents
|
||||
|
||||
- [api.md](api.md) - Complete API reference for IntervenableModel, intervention types, and configurations
|
||||
- [tutorials.md](tutorials.md) - Step-by-step tutorials for causal tracing, activation patching, and trainable interventions
|
||||
|
||||
## Quick Links
|
||||
|
||||
- **Official Documentation**: https://stanfordnlp.github.io/pyvene/
|
||||
- **GitHub Repository**: https://github.com/stanfordnlp/pyvene
|
||||
- **Paper**: https://arxiv.org/abs/2403.07809 (NAACL 2024)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install pyvene
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# Define intervention
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create intervenable model
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# Run intervention (swap activations from source to base)
|
||||
base_inputs = tokenizer("The cat sat on the", return_tensors="pt")
|
||||
source_inputs = tokenizer("The dog ran through the", return_tensors="pt")
|
||||
|
||||
_, outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
)
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Intervention Types
|
||||
- **VanillaIntervention**: Swap activations between runs
|
||||
- **AdditionIntervention**: Add source to base activations
|
||||
- **ZeroIntervention**: Zero out activations (ablation)
|
||||
- **CollectIntervention**: Collect activations without modifying
|
||||
- **RotatedSpaceIntervention**: Trainable intervention for causal discovery
|
||||
|
||||
### Components
|
||||
Target specific parts of the model:
|
||||
- `block_input`, `block_output`
|
||||
- `mlp_input`, `mlp_output`, `mlp_activation`
|
||||
- `attention_input`, `attention_output`
|
||||
- `query_output`, `key_output`, `value_output`
|
||||
|
||||
### HuggingFace Integration
|
||||
Save and load interventions via HuggingFace Hub for reproducibility.
|
||||
@@ -0,0 +1,383 @@
|
||||
# pyvene API Reference
|
||||
|
||||
## IntervenableModel
|
||||
|
||||
The core class that wraps PyTorch models for intervention.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
```
|
||||
|
||||
### Forward Pass
|
||||
|
||||
```python
|
||||
# Basic intervention
|
||||
original_output, intervened_output = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
)
|
||||
|
||||
# With unit locations (position-specific)
|
||||
_, outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
unit_locations={"sources->base": ([[[5]]], [[[5]]])}, # Position 5
|
||||
)
|
||||
|
||||
# Return original output too
|
||||
original, intervened = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
output_original_output=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Generation
|
||||
|
||||
```python
|
||||
# Generate with interventions
|
||||
outputs = intervenable.generate(
|
||||
base_inputs,
|
||||
sources=[source_inputs],
|
||||
max_new_tokens=50,
|
||||
do_sample=False,
|
||||
)
|
||||
```
|
||||
|
||||
### Saving and Loading
|
||||
|
||||
```python
|
||||
# Save locally
|
||||
intervenable.save("./my_intervention")
|
||||
|
||||
# Load
|
||||
intervenable = pv.IntervenableModel.load("./my_intervention", model=model)
|
||||
|
||||
# Save to HuggingFace
|
||||
intervenable.save_intervention("username/my-intervention")
|
||||
|
||||
# Load from HuggingFace
|
||||
intervenable = pv.IntervenableModel.load(
|
||||
"username/my-intervention",
|
||||
model=model
|
||||
)
|
||||
```
|
||||
|
||||
### Getting Trainable Parameters
|
||||
|
||||
```python
|
||||
# For trainable interventions
|
||||
params = intervenable.get_trainable_parameters()
|
||||
optimizer = torch.optim.Adam(params, lr=1e-4)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## IntervenableConfig
|
||||
|
||||
Configuration container for interventions.
|
||||
|
||||
### Basic Config
|
||||
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(...)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Multiple Interventions
|
||||
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(layer=3, component="block_output", ...),
|
||||
pv.RepresentationConfig(layer=5, component="mlp_output", ...),
|
||||
pv.RepresentationConfig(layer=7, component="attention_output", ...),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## RepresentationConfig
|
||||
|
||||
Specifies a single intervention target.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `layer` | int | Layer index |
|
||||
| `component` | str | Component to intervene on |
|
||||
| `intervention_type` | type | Intervention class |
|
||||
| `unit` | str | Intervention unit ("pos", "h", etc.) |
|
||||
| `max_number_of_units` | int | Max units to intervene |
|
||||
| `low_rank_dimension` | int | For trainable interventions |
|
||||
| `subspace_partition` | list | Dimension ranges |
|
||||
|
||||
### Components
|
||||
|
||||
| Component | Description |
|
||||
|-----------|-------------|
|
||||
| `block_input` | Input to transformer block |
|
||||
| `block_output` | Output of transformer block |
|
||||
| `mlp_input` | Input to MLP |
|
||||
| `mlp_output` | Output of MLP |
|
||||
| `mlp_activation` | MLP hidden activations |
|
||||
| `attention_input` | Input to attention |
|
||||
| `attention_output` | Output of attention |
|
||||
| `attention_value_output` | Attention values |
|
||||
| `query_output` | Query vectors |
|
||||
| `key_output` | Key vectors |
|
||||
| `value_output` | Value vectors |
|
||||
| `head_attention_value_output` | Per-head values |
|
||||
|
||||
### Example Configs
|
||||
|
||||
```python
|
||||
# Position-specific intervention
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
unit="pos",
|
||||
max_number_of_units=1,
|
||||
)
|
||||
|
||||
# Trainable low-rank intervention
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.LowRankRotatedSpaceIntervention,
|
||||
low_rank_dimension=64,
|
||||
)
|
||||
|
||||
# Subspace intervention
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
subspace_partition=[[0, 256], [256, 512]], # First 512 dims split
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Intervention Types
|
||||
|
||||
### Basic Interventions
|
||||
|
||||
#### VanillaIntervention
|
||||
Replaces base activations with source activations.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### AdditionIntervention
|
||||
Adds source activations to base.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.AdditionIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### SubtractionIntervention
|
||||
Subtracts source from base.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.SubtractionIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### ZeroIntervention
|
||||
Sets activations to zero (ablation).
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.ZeroIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### CollectIntervention
|
||||
Collects activations without modification.
|
||||
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.CollectIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
_, collected = intervenable(base=inputs)
|
||||
# collected contains the activations
|
||||
```
|
||||
|
||||
### Trainable Interventions
|
||||
|
||||
#### RotatedSpaceIntervention
|
||||
Full-rank trainable rotation.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.RotatedSpaceIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### LowRankRotatedSpaceIntervention
|
||||
Low-rank trainable intervention (DAS).
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.LowRankRotatedSpaceIntervention,
|
||||
low_rank_dimension=64,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### BoundlessRotatedSpaceIntervention
|
||||
Boundless DAS variant.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.BoundlessRotatedSpaceIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
#### SigmoidMaskIntervention
|
||||
Learnable binary mask.
|
||||
|
||||
```python
|
||||
pv.RepresentationConfig(
|
||||
intervention_type=pv.SigmoidMaskIntervention,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Unit Locations
|
||||
|
||||
Specify exactly where to intervene.
|
||||
|
||||
### Format
|
||||
|
||||
```python
|
||||
unit_locations = {
|
||||
"sources->base": (source_locations, base_locations)
|
||||
}
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
# Single position
|
||||
unit_locations = {"sources->base": ([[[5]]], [[[5]]])}
|
||||
|
||||
# Multiple positions
|
||||
unit_locations = {"sources->base": ([[[3, 5, 7]]], [[[3, 5, 7]]])}
|
||||
|
||||
# Different source and base positions
|
||||
unit_locations = {"sources->base": ([[[5]]], [[[10]]])}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Supported Models
|
||||
|
||||
pyvene works with any PyTorch model. Officially tested:
|
||||
|
||||
| Family | Models |
|
||||
|--------|--------|
|
||||
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
|
||||
| LLaMA | llama-7b, llama-2-7b, llama-2-13b |
|
||||
| Pythia | pythia-70m to pythia-12b |
|
||||
| Mistral | mistral-7b, mixtral-8x7b |
|
||||
| Gemma | gemma-2b, gemma-7b |
|
||||
| Vision | BLIP, LLaVA |
|
||||
| Other | OPT, Phi, Qwen, ESM, Mamba |
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference: Common Patterns
|
||||
|
||||
### Activation Patching
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=layer,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Causal Tracing (ROME-style)
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
# First corrupt with noise
|
||||
pv.RepresentationConfig(
|
||||
layer=0,
|
||||
component="block_input",
|
||||
intervention_type=pv.NoiseIntervention,
|
||||
),
|
||||
# Then restore at target layer
|
||||
pv.RepresentationConfig(
|
||||
layer=target_layer,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### DAS (Distributed Alignment Search)
|
||||
```python
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=layer,
|
||||
component="block_output",
|
||||
intervention_type=pv.LowRankRotatedSpaceIntervention,
|
||||
low_rank_dimension=1, # Find 1D causal direction
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
@@ -0,0 +1,376 @@
|
||||
# pyvene Tutorials
|
||||
|
||||
## Tutorial 1: Basic Activation Patching
|
||||
|
||||
### Goal
|
||||
Swap activations between two prompts to test causal relationships.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# 1. Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# 2. Prepare inputs
|
||||
base_prompt = "The Colosseum is in the city of"
|
||||
source_prompt = "The Eiffel Tower is in the city of"
|
||||
|
||||
base_inputs = tokenizer(base_prompt, return_tensors="pt")
|
||||
source_inputs = tokenizer(source_prompt, return_tensors="pt")
|
||||
|
||||
# 3. Define intervention (patch layer 8)
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=8,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 4. Run intervention
|
||||
_, patched_outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
)
|
||||
|
||||
# 5. Check predictions
|
||||
patched_logits = patched_outputs.logits
|
||||
probs = torch.softmax(patched_logits[0, -1], dim=-1)
|
||||
|
||||
rome_token = tokenizer.encode(" Rome")[0]
|
||||
paris_token = tokenizer.encode(" Paris")[0]
|
||||
|
||||
print(f"P(Rome): {probs[rome_token].item():.4f}")
|
||||
print(f"P(Paris): {probs[paris_token].item():.4f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 2: Causal Tracing (ROME-style)
|
||||
|
||||
### Goal
|
||||
Locate where factual associations are stored by corrupting inputs and restoring activations.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
|
||||
|
||||
# 1. Define prompts
|
||||
clean_prompt = "The Space Needle is in downtown"
|
||||
# We'll corrupt by adding noise to embeddings
|
||||
|
||||
clean_inputs = tokenizer(clean_prompt, return_tensors="pt")
|
||||
seattle_token = tokenizer.encode(" Seattle")[0]
|
||||
|
||||
# 2. Get clean baseline
|
||||
with torch.no_grad():
|
||||
clean_outputs = model(**clean_inputs)
|
||||
clean_prob = torch.softmax(clean_outputs.logits[0, -1], dim=-1)[seattle_token].item()
|
||||
|
||||
print(f"Clean P(Seattle): {clean_prob:.4f}")
|
||||
|
||||
# 3. Sweep over layers - corrupt input, restore at each layer
|
||||
results = []
|
||||
|
||||
for restore_layer in range(model.config.n_layer):
|
||||
# Config: add noise at input, restore at target layer
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
# Noise intervention at embedding
|
||||
pv.RepresentationConfig(
|
||||
layer=0,
|
||||
component="block_input",
|
||||
intervention_type=pv.NoiseIntervention,
|
||||
),
|
||||
# Restore clean at target layer
|
||||
pv.RepresentationConfig(
|
||||
layer=restore_layer,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# Source is clean (for restoration), base gets noise
|
||||
_, outputs = intervenable(
|
||||
base=clean_inputs,
|
||||
sources=[clean_inputs], # Restore from clean
|
||||
)
|
||||
|
||||
prob = torch.softmax(outputs.logits[0, -1], dim=-1)[seattle_token].item()
|
||||
results.append(prob)
|
||||
print(f"Restore at layer {restore_layer}: P(Seattle) = {prob:.4f}")
|
||||
|
||||
# 4. Find critical layers (where restoration helps most)
|
||||
import numpy as np
|
||||
results = np.array(results)
|
||||
critical_layers = np.argsort(results)[-5:]
|
||||
print(f"\nMost critical layers: {critical_layers}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 3: Trainable Interventions (DAS)
|
||||
|
||||
### Goal
|
||||
Learn a low-rank intervention that achieves a target counterfactual behavior.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# 1. Define trainable intervention
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=6,
|
||||
component="block_output",
|
||||
intervention_type=pv.LowRankRotatedSpaceIntervention,
|
||||
low_rank_dimension=64, # Learn 64-dim subspace
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 2. Setup optimizer
|
||||
optimizer = torch.optim.Adam(
|
||||
intervenable.get_trainable_parameters(),
|
||||
lr=1e-3
|
||||
)
|
||||
|
||||
# 3. Training data (simplified example)
|
||||
# Goal: Make model predict "Paris" instead of "Rome"
|
||||
base_prompt = "The capital of Italy is"
|
||||
target_token = tokenizer.encode(" Paris")[0]
|
||||
|
||||
base_inputs = tokenizer(base_prompt, return_tensors="pt")
|
||||
|
||||
# 4. Training loop
|
||||
for step in range(100):
|
||||
optimizer.zero_grad()
|
||||
|
||||
_, outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[base_inputs], # Self-intervention
|
||||
)
|
||||
|
||||
# Loss: maximize probability of target token
|
||||
logits = outputs.logits[0, -1]
|
||||
loss = -torch.log_softmax(logits, dim=-1)[target_token]
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if step % 20 == 0:
|
||||
prob = torch.softmax(logits.detach(), dim=-1)[target_token].item()
|
||||
print(f"Step {step}: loss={loss.item():.4f}, P(Paris)={prob:.4f}")
|
||||
|
||||
# 5. Analyze learned rotation
|
||||
rotation = intervenable.interventions["layer.6.comp.block_output.unit.pos.nunit.1#0"][0]
|
||||
print(f"Learned rotation shape: {rotation.rotate_layer.weight.shape}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 4: Position-Specific Intervention
|
||||
|
||||
### Goal
|
||||
Intervene at specific token positions only.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# 1. Setup
|
||||
base_prompt = "John and Mary went to the store"
|
||||
source_prompt = "Alice and Bob went to the store"
|
||||
|
||||
base_inputs = tokenizer(base_prompt, return_tensors="pt")
|
||||
source_inputs = tokenizer(source_prompt, return_tensors="pt")
|
||||
|
||||
# 2. Position-specific config
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.VanillaIntervention,
|
||||
unit="pos",
|
||||
max_number_of_units=1, # Single position
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 3. Intervene at position 0 only (first name)
|
||||
_, outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
unit_locations={"sources->base": ([[[0]]], [[[0]]])},
|
||||
)
|
||||
|
||||
# 4. Intervene at multiple positions
|
||||
_, outputs = intervenable(
|
||||
base=base_inputs,
|
||||
sources=[source_inputs],
|
||||
unit_locations={"sources->base": ([[[0, 2]]], [[[0, 2]]])},
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 5: Collecting Activations
|
||||
|
||||
### Goal
|
||||
Extract activations without modifying them.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# 1. Config with CollectIntervention
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=5,
|
||||
component="block_output",
|
||||
intervention_type=pv.CollectIntervention,
|
||||
),
|
||||
pv.RepresentationConfig(
|
||||
layer=10,
|
||||
component="attention_output",
|
||||
intervention_type=pv.CollectIntervention,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 2. Run and collect
|
||||
inputs = tokenizer("Hello world", return_tensors="pt")
|
||||
_, collected = intervenable(base=inputs)
|
||||
|
||||
# 3. Access collected activations
|
||||
layer5_output = collected[0]
|
||||
layer10_attn = collected[1]
|
||||
|
||||
print(f"Layer 5 block output shape: {layer5_output.shape}")
|
||||
print(f"Layer 10 attention output shape: {layer10_attn.shape}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 6: Generation with Interventions
|
||||
|
||||
### Goal
|
||||
Apply interventions during text generation.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
import pyvene as pv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 1. Get steering direction (happy vs sad)
|
||||
happy_inputs = tokenizer("I am very happy and", return_tensors="pt")
|
||||
sad_inputs = tokenizer("I am very sad and", return_tensors="pt")
|
||||
|
||||
# Collect activations
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=6,
|
||||
component="mlp_output",
|
||||
intervention_type=pv.CollectIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
collector = pv.IntervenableModel(config, model)
|
||||
|
||||
_, happy_acts = collector(base=happy_inputs)
|
||||
_, sad_acts = collector(base=sad_inputs)
|
||||
|
||||
steering_direction = happy_acts[0].mean(dim=1) - sad_acts[0].mean(dim=1)
|
||||
|
||||
# 2. Config for steering during generation
|
||||
config = pv.IntervenableConfig(
|
||||
representations=[
|
||||
pv.RepresentationConfig(
|
||||
layer=6,
|
||||
component="mlp_output",
|
||||
intervention_type=pv.AdditionIntervention,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
intervenable = pv.IntervenableModel(config, model)
|
||||
|
||||
# 3. Generate with steering
|
||||
prompt = "Today I feel"
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
# Create source with steering direction
|
||||
# (This is simplified - actual implementation varies)
|
||||
output = intervenable.generate(
|
||||
inputs,
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
print(tokenizer.decode(output[0]))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Resources
|
||||
|
||||
### Official Tutorials
|
||||
- [pyvene 101](https://stanfordnlp.github.io/pyvene/tutorials/pyvene_101.html)
|
||||
- [Causal Tracing](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/Causal_Tracing.html)
|
||||
- [DAS Introduction](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/DAS_Main_Introduction.html)
|
||||
- [IOI Replication](https://stanfordnlp.github.io/pyvene/tutorials/advanced_tutorials/IOI_Replication.html)
|
||||
|
||||
### Papers
|
||||
- [pyvene Paper](https://arxiv.org/abs/2403.07809) - NAACL 2024
|
||||
- [ROME](https://arxiv.org/abs/2202.05262) - Meng et al. (2022)
|
||||
- [Inference-Time Intervention](https://arxiv.org/abs/2306.03341) - Li et al. (2023)
|
||||
@@ -0,0 +1,386 @@
|
||||
---
|
||||
name: sparse-autoencoder-training
|
||||
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
|
||||
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
|
||||
---
|
||||
|
||||
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
|
||||
|
||||
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.
|
||||
|
||||
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
|
||||
|
||||
## The Problem: Polysemanticity & Superposition
|
||||
|
||||
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
|
||||
|
||||
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
|
||||
|
||||
## When to Use SAELens
|
||||
|
||||
**Use SAELens when you need to:**
|
||||
- Discover interpretable features in model activations
|
||||
- Understand what concepts a model has learned
|
||||
- Study superposition and feature geometry
|
||||
- Perform feature-based steering or ablation
|
||||
- Analyze safety-relevant features (deception, bias, harmful content)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need basic activation analysis → Use **TransformerLens** directly
|
||||
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
|
||||
- You need production steering → Consider direct activation engineering
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install sae-lens
|
||||
```
|
||||
|
||||
Requirements: Python 3.10+, transformer-lens>=2.0.0
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### What SAEs Learn
|
||||
|
||||
SAEs are trained to reconstruct model activations through a sparse bottleneck:
|
||||
|
||||
```
|
||||
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
|
||||
(d_model) ↓ (d_sae >> d_model) ↓ (d_model)
|
||||
sparsity reconstruction
|
||||
penalty loss
|
||||
```
|
||||
|
||||
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
|
||||
|
||||
### Key Validation (Anthropic Research)
|
||||
|
||||
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
|
||||
- DNA sequences, legal language, HTTP requests
|
||||
- Hebrew text, nutrition statements, code syntax
|
||||
- Sentiment, named entities, grammatical structures
|
||||
|
||||
## Workflow 1: Loading and Analyzing Pre-trained SAEs
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
|
||||
# 1. Load model and pre-trained SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 2. Get model activations
|
||||
tokens = model.to_tokens("The capital of France is Paris")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8] # [batch, pos, d_model]
|
||||
|
||||
# 3. Encode to SAE features
|
||||
sae_features = sae.encode(activations) # [batch, pos, d_sae]
|
||||
print(f"Active features: {(sae_features > 0).sum()}")
|
||||
|
||||
# 4. Find top features for each position
|
||||
for pos in range(tokens.shape[1]):
|
||||
top_features = sae_features[0, pos].topk(5)
|
||||
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
|
||||
print(f"Token '{token}': features {top_features.indices.tolist()}")
|
||||
|
||||
# 5. Reconstruct activations
|
||||
reconstructed = sae.decode(sae_features)
|
||||
reconstruction_error = (activations - reconstructed).norm()
|
||||
```
|
||||
|
||||
### Available Pre-trained SAEs
|
||||
|
||||
| Release | Model | Layers |
|
||||
|---------|-------|--------|
|
||||
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
|
||||
| `gemma-2b-res` | Gemma 2B | Residual streams |
|
||||
| Various on HuggingFace | Search tag `saelens` | Various |
|
||||
|
||||
### Checklist
|
||||
- [ ] Load model with TransformerLens
|
||||
- [ ] Load matching SAE for target layer
|
||||
- [ ] Encode activations to sparse features
|
||||
- [ ] Identify top-activating features per token
|
||||
- [ ] Validate reconstruction quality
|
||||
|
||||
## Workflow 2: Training a Custom SAE
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
|
||||
|
||||
# 1. Configure training
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.8.hook_resid_pre",
|
||||
hook_layer=8,
|
||||
d_in=768, # Model dimension
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard", # or "gated", "topk"
|
||||
d_sae=768 * 8, # Expansion factor of 8
|
||||
activation_fn="relu",
|
||||
|
||||
# Training
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5, # Sparsity penalty
|
||||
l1_warm_up_steps=1000,
|
||||
train_batch_size_tokens=4096,
|
||||
training_tokens=100_000_000,
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
context_size=128,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training",
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_path="checkpoints",
|
||||
n_checkpoints=5,
|
||||
)
|
||||
|
||||
# 2. Train
|
||||
trainer = SAETrainingRunner(cfg)
|
||||
sae = trainer.run()
|
||||
|
||||
# 3. Evaluate
|
||||
print(f"L0 (avg active features): {trainer.metrics['l0']}")
|
||||
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
|
||||
```
|
||||
|
||||
### Key Hyperparameters
|
||||
|
||||
| Parameter | Typical Value | Effect |
|
||||
|-----------|---------------|--------|
|
||||
| `d_sae` | 4-16× d_model | More features, higher capacity |
|
||||
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
|
||||
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
|
||||
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
|
||||
|
||||
### Evaluation Metrics
|
||||
|
||||
| Metric | Target | Meaning |
|
||||
|--------|--------|---------|
|
||||
| **L0** | 50-200 | Average active features per token |
|
||||
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
|
||||
| **Dead Features** | <5% | Features that never activate |
|
||||
| **Explained Variance** | >90% | Reconstruction quality |
|
||||
|
||||
### Checklist
|
||||
- [ ] Choose target layer and hook point
|
||||
- [ ] Set expansion factor (d_sae = 4-16× d_model)
|
||||
- [ ] Tune L1 coefficient for desired sparsity
|
||||
- [ ] Enable L1 warm-up to prevent dead features
|
||||
- [ ] Monitor metrics during training (W&B)
|
||||
- [ ] Validate L0 and CE loss recovery
|
||||
- [ ] Check dead feature ratio
|
||||
|
||||
## Workflow 3: Feature Analysis and Steering
|
||||
|
||||
### Analyzing Individual Features
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Find what activates a specific feature
|
||||
feature_idx = 1234
|
||||
test_texts = [
|
||||
"The scientist conducted an experiment",
|
||||
"I love chocolate cake",
|
||||
"The code compiles successfully",
|
||||
"Paris is beautiful in spring",
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
tokens = model.to_tokens(text)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
features = sae.encode(cache["resid_pre", 8])
|
||||
activation = features[0, :, feature_idx].max().item()
|
||||
print(f"{activation:.3f}: {text}")
|
||||
```
|
||||
|
||||
### Feature Steering
|
||||
|
||||
```python
|
||||
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
|
||||
"""Add SAE feature direction to residual stream."""
|
||||
tokens = model.to_tokens(prompt)
|
||||
|
||||
# Get feature direction from decoder
|
||||
feature_direction = sae.W_dec[feature_idx] # [d_model]
|
||||
|
||||
def steering_hook(activation, hook):
|
||||
# Add scaled feature direction at all positions
|
||||
activation += strength * feature_direction
|
||||
return activation
|
||||
|
||||
# Generate with steering
|
||||
output = model.generate(
|
||||
tokens,
|
||||
max_new_tokens=50,
|
||||
fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]
|
||||
)
|
||||
return model.to_string(output[0])
|
||||
```
|
||||
|
||||
### Feature Attribution
|
||||
|
||||
```python
|
||||
# Which features most affect a specific output?
|
||||
tokens = model.to_tokens("The capital of France is")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Get features at final position
|
||||
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
|
||||
|
||||
# Get logit attribution per feature
|
||||
# Feature contribution = feature_activation × decoder_weight × unembedding
|
||||
W_dec = sae.W_dec # [d_sae, d_model]
|
||||
W_U = model.W_U # [d_model, vocab]
|
||||
|
||||
# Contribution to "Paris" logit
|
||||
paris_token = model.to_single_token(" Paris")
|
||||
feature_contributions = features * (W_dec @ W_U[:, paris_token])
|
||||
|
||||
top_features = feature_contributions.topk(10)
|
||||
print("Top features for 'Paris' prediction:")
|
||||
for idx, val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {idx.item()}: {val.item():.3f}")
|
||||
```
|
||||
|
||||
## Common Issues & Solutions
|
||||
|
||||
### Issue: High dead feature ratio
|
||||
```python
|
||||
# WRONG: No warm-up, features die early
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=1e-4,
|
||||
l1_warm_up_steps=0, # Bad!
|
||||
)
|
||||
|
||||
# RIGHT: Warm-up L1 penalty
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=8e-5,
|
||||
l1_warm_up_steps=1000, # Gradually increase
|
||||
use_ghost_grads=True, # Revive dead features
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Poor reconstruction (low CE recovery)
|
||||
```python
|
||||
# Reduce sparsity penalty
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=5e-5, # Lower = better reconstruction
|
||||
d_sae=768 * 16, # More capacity
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Features not interpretable
|
||||
```python
|
||||
# Increase sparsity (higher L1)
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
l1_coefficient=1e-4, # Higher = sparser, more interpretable
|
||||
)
|
||||
# Or use TopK architecture
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Memory errors during training
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
train_batch_size_tokens=2048, # Reduce batch size
|
||||
store_batch_size_prompts=4, # Fewer prompts in buffer
|
||||
n_batches_in_buffer=8, # Smaller activation buffer
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with Neuronpedia
|
||||
|
||||
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
|
||||
|
||||
```python
|
||||
# Features are indexed by SAE ID
|
||||
# Example: gpt2-small layer 8 feature 1234
|
||||
# → neuronpedia.org/gpt2-small/8-res-jb/1234
|
||||
```
|
||||
|
||||
## Key Classes Reference
|
||||
|
||||
| Class | Purpose |
|
||||
|-------|---------|
|
||||
| `SAE` | Sparse Autoencoder model |
|
||||
| `LanguageModelSAERunnerConfig` | Training configuration |
|
||||
| `SAETrainingRunner` | Training loop manager |
|
||||
| `ActivationsStore` | Activation collection and batching |
|
||||
| `HookedSAETransformer` | TransformerLens + SAE integration |
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| [references/README.md](references/README.md) | Overview and quick start guide |
|
||||
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
|
||||
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
|
||||
|
||||
## External Resources
|
||||
|
||||
### Tutorials
|
||||
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
||||
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
||||
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
|
||||
|
||||
### Papers
|
||||
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
|
||||
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
|
||||
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
|
||||
|
||||
### Official Documentation
|
||||
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
|
||||
- [Neuronpedia](https://neuronpedia.org) - Feature browser
|
||||
|
||||
## SAE Architectures
|
||||
|
||||
| Architecture | Description | Use Case |
|
||||
|--------------|-------------|----------|
|
||||
| **Standard** | ReLU + L1 penalty | General purpose |
|
||||
| **Gated** | Learned gating mechanism | Better sparsity control |
|
||||
| **TopK** | Exactly K active features | Consistent sparsity |
|
||||
|
||||
```python
|
||||
# TopK SAE (exactly 50 features active)
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn="topk",
|
||||
activation_fn_kwargs={"k": 50},
|
||||
)
|
||||
```
|
||||
@@ -0,0 +1,70 @@
|
||||
# SAELens Reference Documentation
|
||||
|
||||
This directory contains comprehensive reference materials for SAELens.
|
||||
|
||||
## Contents
|
||||
|
||||
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
|
||||
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
|
||||
- [papers.md](papers.md) - Key research papers on sparse autoencoders
|
||||
|
||||
## Quick Links
|
||||
|
||||
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
|
||||
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
|
||||
- **HuggingFace SAEs**: Search for tag `saelens`
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install sae-lens
|
||||
```
|
||||
|
||||
Requirements: Python 3.10+, transformer-lens>=2.0.0
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
|
||||
# Load model and SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Encode activations to sparse features
|
||||
tokens = model.to_tokens("Hello world")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
|
||||
features = sae.encode(activations) # Sparse feature activations
|
||||
reconstructed = sae.decode(features) # Reconstructed activations
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Sparse Autoencoders
|
||||
SAEs decompose dense neural activations into sparse, interpretable features:
|
||||
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
|
||||
- **ReLU/TopK**: Enforces sparsity
|
||||
- **Decoder**: Reconstructs original activations
|
||||
|
||||
### Training Loss
|
||||
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
|
||||
|
||||
### Key Metrics
|
||||
- **L0**: Average number of active features (target: 50-200)
|
||||
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
|
||||
- **Dead Features**: Features that never activate (target: <5%)
|
||||
|
||||
## Available Pre-trained SAEs
|
||||
|
||||
| Release | Model | Description |
|
||||
|---------|-------|-------------|
|
||||
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
|
||||
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
|
||||
| Various | Search HuggingFace | Community-trained SAEs |
|
||||
@@ -0,0 +1,333 @@
|
||||
# SAELens API Reference
|
||||
|
||||
## SAE Class
|
||||
|
||||
The core class representing a Sparse Autoencoder.
|
||||
|
||||
### Loading Pre-trained SAEs
|
||||
|
||||
```python
|
||||
from sae_lens import SAE
|
||||
|
||||
# From official releases
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# From HuggingFace
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="username/repo-name",
|
||||
sae_id="path/to/sae",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# From local disk
|
||||
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
|
||||
```
|
||||
|
||||
### SAE Attributes
|
||||
|
||||
| Attribute | Shape | Description |
|
||||
|-----------|-------|-------------|
|
||||
| `W_enc` | [d_in, d_sae] | Encoder weights |
|
||||
| `W_dec` | [d_sae, d_in] | Decoder weights |
|
||||
| `b_enc` | [d_sae] | Encoder bias |
|
||||
| `b_dec` | [d_in] | Decoder bias |
|
||||
| `cfg` | SAEConfig | Configuration object |
|
||||
|
||||
### Core Methods
|
||||
|
||||
#### encode()
|
||||
|
||||
```python
|
||||
# Encode activations to sparse features
|
||||
features = sae.encode(activations)
|
||||
# Input: [batch, pos, d_in]
|
||||
# Output: [batch, pos, d_sae]
|
||||
```
|
||||
|
||||
#### decode()
|
||||
|
||||
```python
|
||||
# Reconstruct activations from features
|
||||
reconstructed = sae.decode(features)
|
||||
# Input: [batch, pos, d_sae]
|
||||
# Output: [batch, pos, d_in]
|
||||
```
|
||||
|
||||
#### forward()
|
||||
|
||||
```python
|
||||
# Full forward pass (encode + decode)
|
||||
reconstructed = sae(activations)
|
||||
# Returns reconstructed activations
|
||||
```
|
||||
|
||||
#### save_model()
|
||||
|
||||
```python
|
||||
sae.save_model("/path/to/save")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## SAEConfig
|
||||
|
||||
Configuration class for SAE architecture and training context.
|
||||
|
||||
### Key Parameters
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `d_in` | int | Input dimension (model's d_model) |
|
||||
| `d_sae` | int | SAE hidden dimension |
|
||||
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
|
||||
| `activation_fn_str` | str | Activation function name |
|
||||
| `model_name` | str | Source model name |
|
||||
| `hook_name` | str | Hook point in model |
|
||||
| `normalize_activations` | str | Normalization method |
|
||||
| `dtype` | str | Data type |
|
||||
| `device` | str | Device |
|
||||
|
||||
### Accessing Config
|
||||
|
||||
```python
|
||||
print(sae.cfg.d_in) # 768 for GPT-2 small
|
||||
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
|
||||
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## LanguageModelSAERunnerConfig
|
||||
|
||||
Comprehensive configuration for training SAEs.
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```python
|
||||
from sae_lens import LanguageModelSAERunnerConfig
|
||||
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model and hook
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.8.hook_resid_pre",
|
||||
hook_layer=8,
|
||||
d_in=768,
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard", # "standard", "gated", "jumprelu", "topk"
|
||||
d_sae=768 * 8, # Expansion factor
|
||||
activation_fn="relu",
|
||||
|
||||
# Training hyperparameters
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5,
|
||||
lp_norm=1.0,
|
||||
lr_scheduler_name="constant",
|
||||
lr_warm_up_steps=500,
|
||||
|
||||
# Sparsity control
|
||||
l1_warm_up_steps=1000,
|
||||
use_ghost_grads=True,
|
||||
feature_sampling_window=1000,
|
||||
dead_feature_window=5000,
|
||||
dead_feature_threshold=1e-8,
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
streaming=True,
|
||||
context_size=128,
|
||||
|
||||
# Batch sizes
|
||||
train_batch_size_tokens=4096,
|
||||
store_batch_size_prompts=16,
|
||||
n_batches_in_buffer=64,
|
||||
|
||||
# Training duration
|
||||
training_tokens=100_000_000,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training",
|
||||
wandb_log_frequency=100,
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_path="checkpoints",
|
||||
n_checkpoints=5,
|
||||
|
||||
# Hardware
|
||||
device="cuda",
|
||||
dtype="float32",
|
||||
)
|
||||
```
|
||||
|
||||
### Key Parameters Explained
|
||||
|
||||
#### Architecture Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
|
||||
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
|
||||
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
|
||||
| `activation_fn` | "relu", "topk", etc. |
|
||||
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
|
||||
|
||||
#### Sparsity Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
|
||||
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
|
||||
| `use_ghost_grads` | Apply gradients to dead features |
|
||||
| `dead_feature_threshold` | Activation threshold for "dead" |
|
||||
| `dead_feature_window` | Steps to check for dead features |
|
||||
|
||||
#### Learning Rate Parameters
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `lr` | Base learning rate |
|
||||
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
|
||||
| `lr_warm_up_steps` | LR warmup steps |
|
||||
| `lr_decay_steps` | Steps for LR decay |
|
||||
|
||||
---
|
||||
|
||||
## SAETrainingRunner
|
||||
|
||||
Main class for executing training.
|
||||
|
||||
### Basic Training
|
||||
|
||||
```python
|
||||
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
|
||||
|
||||
cfg = LanguageModelSAERunnerConfig(...)
|
||||
runner = SAETrainingRunner(cfg)
|
||||
sae = runner.run()
|
||||
```
|
||||
|
||||
### Accessing Training Metrics
|
||||
|
||||
```python
|
||||
# During training, metrics logged to W&B include:
|
||||
# - l0: Average active features
|
||||
# - ce_loss_score: Cross-entropy recovery
|
||||
# - mse_loss: Reconstruction loss
|
||||
# - l1_loss: Sparsity loss
|
||||
# - dead_features: Count of dead features
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ActivationsStore
|
||||
|
||||
Manages activation collection and batching.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from sae_lens import ActivationsStore
|
||||
|
||||
store = ActivationsStore.from_sae(
|
||||
model=model,
|
||||
sae=sae,
|
||||
store_batch_size_prompts=8,
|
||||
train_batch_size_tokens=4096,
|
||||
n_batches_in_buffer=32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Get batch of activations
|
||||
activations = store.get_batch_tokens()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## HookedSAETransformer
|
||||
|
||||
Integration of SAEs with TransformerLens models.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from sae_lens import HookedSAETransformer
|
||||
|
||||
# Load model with SAE
|
||||
model = HookedSAETransformer.from_pretrained("gpt2-small")
|
||||
model.add_sae(sae)
|
||||
|
||||
# Run with SAE in the loop
|
||||
output = model.run_with_saes(tokens, saes=[sae])
|
||||
|
||||
# Cache with SAE activations
|
||||
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## SAE Architectures
|
||||
|
||||
### Standard (ReLU + L1)
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="standard",
|
||||
activation_fn="relu",
|
||||
l1_coefficient=8e-5,
|
||||
)
|
||||
```
|
||||
|
||||
### Gated
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="gated",
|
||||
)
|
||||
```
|
||||
|
||||
### TopK
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="topk",
|
||||
activation_fn="topk",
|
||||
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
|
||||
)
|
||||
```
|
||||
|
||||
### JumpReLU (State-of-the-art)
|
||||
|
||||
```python
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
architecture="jumprelu",
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Utility Functions
|
||||
|
||||
### Upload to HuggingFace
|
||||
|
||||
```python
|
||||
from sae_lens import upload_saes_to_huggingface
|
||||
|
||||
upload_saes_to_huggingface(
|
||||
saes=[sae],
|
||||
repo_id="username/my-saes",
|
||||
token="hf_token",
|
||||
)
|
||||
```
|
||||
|
||||
### Neuronpedia Integration
|
||||
|
||||
```python
|
||||
# Features can be viewed on Neuronpedia
|
||||
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
|
||||
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
|
||||
```
|
||||
@@ -0,0 +1,318 @@
|
||||
# SAELens Tutorials
|
||||
|
||||
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
|
||||
|
||||
### Goal
|
||||
Load a pre-trained SAE and analyze which features activate on specific inputs.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
# 1. Load model and SAE
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, cfg_dict, sparsity = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
print(f"SAE input dim: {sae.cfg.d_in}")
|
||||
print(f"SAE hidden dim: {sae.cfg.d_sae}")
|
||||
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
|
||||
|
||||
# 2. Get model activations
|
||||
prompt = "The capital of France is Paris"
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8] # [1, seq_len, 768]
|
||||
|
||||
# 3. Encode to SAE features
|
||||
features = sae.encode(activations) # [1, seq_len, d_sae]
|
||||
|
||||
# 4. Analyze sparsity
|
||||
active_per_token = (features > 0).sum(dim=-1)
|
||||
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
|
||||
|
||||
# 5. Find top features for each token
|
||||
str_tokens = model.to_str_tokens(prompt)
|
||||
for pos in range(len(str_tokens)):
|
||||
top_features = features[0, pos].topk(5)
|
||||
print(f"\nToken '{str_tokens[pos]}':")
|
||||
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
|
||||
|
||||
# 6. Check reconstruction quality
|
||||
reconstructed = sae.decode(features)
|
||||
mse = ((activations - reconstructed) ** 2).mean()
|
||||
print(f"\nReconstruction MSE: {mse.item():.6f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 2: Training a Custom SAE
|
||||
|
||||
### Goal
|
||||
Train a Sparse Autoencoder on GPT-2 activations.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
|
||||
|
||||
# 1. Configure training
|
||||
cfg = LanguageModelSAERunnerConfig(
|
||||
# Model
|
||||
model_name="gpt2-small",
|
||||
hook_name="blocks.6.hook_resid_pre",
|
||||
hook_layer=6,
|
||||
d_in=768,
|
||||
|
||||
# SAE architecture
|
||||
architecture="standard",
|
||||
d_sae=768 * 8, # 8x expansion
|
||||
activation_fn="relu",
|
||||
|
||||
# Training
|
||||
lr=4e-4,
|
||||
l1_coefficient=8e-5,
|
||||
l1_warm_up_steps=1000,
|
||||
train_batch_size_tokens=4096,
|
||||
training_tokens=10_000_000, # Small run for demo
|
||||
|
||||
# Data
|
||||
dataset_path="monology/pile-uncopyrighted",
|
||||
streaming=True,
|
||||
context_size=128,
|
||||
|
||||
# Dead feature prevention
|
||||
use_ghost_grads=True,
|
||||
dead_feature_window=5000,
|
||||
|
||||
# Logging
|
||||
log_to_wandb=True,
|
||||
wandb_project="sae-training-demo",
|
||||
|
||||
# Hardware
|
||||
device="cuda",
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
# 2. Train
|
||||
runner = SAETrainingRunner(cfg)
|
||||
sae = runner.run()
|
||||
|
||||
# 3. Save
|
||||
sae.save_model("./my_trained_sae")
|
||||
```
|
||||
|
||||
### Hyperparameter Tuning Guide
|
||||
|
||||
| If you see... | Try... |
|
||||
|---------------|--------|
|
||||
| High L0 (>200) | Increase `l1_coefficient` |
|
||||
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
|
||||
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
|
||||
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 3: Feature Attribution and Steering
|
||||
|
||||
### Goal
|
||||
Identify which SAE features contribute to specific predictions and use them for steering.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 1. Feature attribution for a specific prediction
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
|
||||
# Target token
|
||||
target_token = model.to_single_token(" Paris")
|
||||
|
||||
# Compute feature contributions to target logit
|
||||
# contribution = feature_activation * decoder_weight * unembedding
|
||||
W_dec = sae.W_dec # [d_sae, d_model]
|
||||
W_U = model.W_U # [d_model, d_vocab]
|
||||
|
||||
# Feature direction projected to vocabulary
|
||||
feature_to_logit = W_dec @ W_U # [d_sae, d_vocab]
|
||||
|
||||
# Contribution of each feature to "Paris" at final position
|
||||
feature_acts = features[0, -1] # [d_sae]
|
||||
contributions = feature_acts * feature_to_logit[:, target_token]
|
||||
|
||||
# Top contributing features
|
||||
top_features = contributions.topk(10)
|
||||
print("Top features contributing to 'Paris':")
|
||||
for idx, val in zip(top_features.indices, top_features.values):
|
||||
print(f" Feature {idx.item()}: {val.item():.3f}")
|
||||
|
||||
# 2. Feature steering
|
||||
def steer_with_feature(feature_idx, strength=5.0):
|
||||
"""Add a feature direction to the residual stream."""
|
||||
feature_direction = sae.W_dec[feature_idx] # [d_model]
|
||||
|
||||
def hook(activation, hook_obj):
|
||||
activation[:, -1, :] += strength * feature_direction
|
||||
return activation
|
||||
|
||||
output = model.generate(
|
||||
tokens,
|
||||
max_new_tokens=10,
|
||||
fwd_hooks=[("blocks.8.hook_resid_pre", hook)]
|
||||
)
|
||||
return model.to_string(output[0])
|
||||
|
||||
# Try steering with top feature
|
||||
top_feature_idx = top_features.indices[0].item()
|
||||
print(f"\nSteering with feature {top_feature_idx}:")
|
||||
print(steer_with_feature(top_feature_idx, strength=10.0))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 4: Feature Ablation
|
||||
|
||||
### Goal
|
||||
Test the causal importance of features by ablating them.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
|
||||
# Baseline prediction
|
||||
baseline_logits = model(tokens)
|
||||
target_token = model.to_single_token(" Paris")
|
||||
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
|
||||
print(f"Baseline P(Paris): {baseline_prob:.4f}")
|
||||
|
||||
# Get features to ablate
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
top_features = features[0, -1].topk(10).indices
|
||||
|
||||
# Ablate top features one by one
|
||||
for feat_idx in top_features:
|
||||
def ablation_hook(activation, hook, feat_idx=feat_idx):
|
||||
# Encode → zero feature → decode
|
||||
feats = sae.encode(activation)
|
||||
feats[:, :, feat_idx] = 0
|
||||
return sae.decode(feats)
|
||||
|
||||
ablated_logits = model.run_with_hooks(
|
||||
tokens,
|
||||
fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
|
||||
)
|
||||
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
|
||||
change = (ablated_prob - baseline_prob) / baseline_prob * 100
|
||||
print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 5: Comparing Features Across Prompts
|
||||
|
||||
### Goal
|
||||
Find which features activate consistently for a concept.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
|
||||
sae, _, _ = SAE.from_pretrained(
|
||||
release="gpt2-small-res-jb",
|
||||
sae_id="blocks.8.hook_resid_pre",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Test prompts about the same concept
|
||||
prompts = [
|
||||
"The Eiffel Tower is located in",
|
||||
"Paris is the capital of",
|
||||
"France's largest city is",
|
||||
"The Louvre museum is in",
|
||||
]
|
||||
|
||||
# Collect feature activations
|
||||
all_features = []
|
||||
for prompt in prompts:
|
||||
tokens = model.to_tokens(prompt)
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
activations = cache["resid_pre", 8]
|
||||
features = sae.encode(activations)
|
||||
# Take max activation across positions
|
||||
max_features = features[0].max(dim=0).values
|
||||
all_features.append(max_features)
|
||||
|
||||
all_features = torch.stack(all_features) # [n_prompts, d_sae]
|
||||
|
||||
# Find features that activate consistently
|
||||
mean_activation = all_features.mean(dim=0)
|
||||
min_activation = all_features.min(dim=0).values
|
||||
|
||||
# Features active in ALL prompts
|
||||
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
|
||||
print(f"Features active in all prompts: {len(consistent_features)}")
|
||||
|
||||
# Top consistent features
|
||||
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
|
||||
print("\nTop consistent features (possibly 'France/Paris' related):")
|
||||
for idx, val in zip(top_consistent.indices, top_consistent.values):
|
||||
feat_idx = consistent_features[idx].item()
|
||||
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Resources
|
||||
|
||||
### Official Tutorials
|
||||
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
||||
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
||||
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
|
||||
|
||||
### ARENA Curriculum
|
||||
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
|
||||
|
||||
### Key Papers
|
||||
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
|
||||
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
|
||||
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024
|
||||
@@ -0,0 +1,346 @@
|
||||
---
|
||||
name: transformer-lens-interpretability
|
||||
description: Provides guidance for mechanistic interpretability research using TransformerLens to inspect and manipulate transformer internals via HookPoints and activation caching. Use when reverse-engineering model algorithms, studying attention patterns, or performing activation patching experiments.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Mechanistic Interpretability, TransformerLens, Activation Patching, Circuit Analysis]
|
||||
dependencies: [transformer-lens>=2.0.0, torch>=2.0.0]
|
||||
---
|
||||
|
||||
# TransformerLens: Mechanistic Interpretability for Transformers
|
||||
|
||||
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
|
||||
|
||||
**GitHub**: [TransformerLensOrg/TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (2,900+ stars)
|
||||
|
||||
## When to Use TransformerLens
|
||||
|
||||
**Use TransformerLens when you need to:**
|
||||
- Reverse-engineer algorithms learned during training
|
||||
- Perform activation patching / causal tracing experiments
|
||||
- Study attention patterns and information flow
|
||||
- Analyze circuits (e.g., induction heads, IOI circuit)
|
||||
- Cache and inspect intermediate activations
|
||||
- Apply direct logit attribution
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need to work with non-transformer architectures → Use **nnsight** or **pyvene**
|
||||
- You want to train/analyze Sparse Autoencoders → Use **SAELens**
|
||||
- You need remote execution on massive models → Use **nnsight** with NDIF
|
||||
- You want higher-level causal intervention abstractions → Use **pyvene**
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install transformer-lens
|
||||
```
|
||||
|
||||
For development version:
|
||||
```bash
|
||||
pip install git+https://github.com/TransformerLensOrg/TransformerLens
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### HookedTransformer
|
||||
|
||||
The main class that wraps transformer models with HookPoints on every activation:
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
|
||||
# Load a model
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# For gated models (LLaMA, Mistral)
|
||||
import os
|
||||
os.environ["HF_TOKEN"] = "your_token"
|
||||
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
```
|
||||
|
||||
### Supported Models (50+)
|
||||
|
||||
| Family | Models |
|
||||
|--------|--------|
|
||||
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
|
||||
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
|
||||
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
|
||||
| Mistral | mistral-7b, mixtral-8x7b |
|
||||
| Others | phi, qwen, opt, gemma |
|
||||
|
||||
### Activation Caching
|
||||
|
||||
Run the model and cache all intermediate activations:
|
||||
|
||||
```python
|
||||
# Get all activations
|
||||
tokens = model.to_tokens("The Eiffel Tower is in")
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Access specific activations
|
||||
residual = cache["resid_post", 5] # Layer 5 residual stream
|
||||
attn_pattern = cache["pattern", 3] # Layer 3 attention pattern
|
||||
mlp_out = cache["mlp_out", 7] # Layer 7 MLP output
|
||||
|
||||
# Filter which activations to cache (saves memory)
|
||||
logits, cache = model.run_with_cache(
|
||||
tokens,
|
||||
names_filter=lambda name: "resid_post" in name
|
||||
)
|
||||
```
|
||||
|
||||
### ActivationCache Keys
|
||||
|
||||
| Key Pattern | Shape | Description |
|
||||
|-------------|-------|-------------|
|
||||
| `resid_pre, layer` | [batch, pos, d_model] | Residual before attention |
|
||||
| `resid_mid, layer` | [batch, pos, d_model] | Residual after attention |
|
||||
| `resid_post, layer` | [batch, pos, d_model] | Residual after MLP |
|
||||
| `attn_out, layer` | [batch, pos, d_model] | Attention output |
|
||||
| `mlp_out, layer` | [batch, pos, d_model] | MLP output |
|
||||
| `pattern, layer` | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
|
||||
| `q, layer` | [batch, pos, head, d_head] | Query vectors |
|
||||
| `k, layer` | [batch, pos, head, d_head] | Key vectors |
|
||||
| `v, layer` | [batch, pos, head, d_head] | Value vectors |
|
||||
|
||||
## Workflow 1: Activation Patching (Causal Tracing)
|
||||
|
||||
Identify which activations causally affect model output by patching clean activations into corrupted runs.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer, patching
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# 1. Define clean and corrupted prompts
|
||||
clean_prompt = "The Eiffel Tower is in the city of"
|
||||
corrupted_prompt = "The Colosseum is in the city of"
|
||||
|
||||
clean_tokens = model.to_tokens(clean_prompt)
|
||||
corrupted_tokens = model.to_tokens(corrupted_prompt)
|
||||
|
||||
# 2. Get clean activations
|
||||
_, clean_cache = model.run_with_cache(clean_tokens)
|
||||
|
||||
# 3. Define metric (e.g., logit difference)
|
||||
paris_token = model.to_single_token(" Paris")
|
||||
rome_token = model.to_single_token(" Rome")
|
||||
|
||||
def metric(logits):
|
||||
return logits[0, -1, paris_token] - logits[0, -1, rome_token]
|
||||
|
||||
# 4. Patch each position and layer
|
||||
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
|
||||
|
||||
for layer in range(model.cfg.n_layers):
|
||||
for pos in range(clean_tokens.shape[1]):
|
||||
def patch_hook(activation, hook):
|
||||
activation[0, pos] = clean_cache[hook.name][0, pos]
|
||||
return activation
|
||||
|
||||
patched_logits = model.run_with_hooks(
|
||||
corrupted_tokens,
|
||||
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
|
||||
)
|
||||
results[layer, pos] = metric(patched_logits)
|
||||
|
||||
# 5. Visualize results (layer x position heatmap)
|
||||
```
|
||||
|
||||
### Checklist
|
||||
- [ ] Define clean and corrupted inputs that differ minimally
|
||||
- [ ] Choose metric that captures behavior difference
|
||||
- [ ] Cache clean activations
|
||||
- [ ] Systematically patch each (layer, position) combination
|
||||
- [ ] Visualize results as heatmap
|
||||
- [ ] Identify causal hotspots
|
||||
|
||||
## Workflow 2: Circuit Analysis (Indirect Object Identification)
|
||||
|
||||
Replicate the IOI circuit discovery from "Interpretability in the Wild".
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# IOI task: "When John and Mary went to the store, Mary gave a bottle to"
|
||||
# Model should predict "John" (indirect object)
|
||||
|
||||
prompt = "When John and Mary went to the store, Mary gave a bottle to"
|
||||
tokens = model.to_tokens(prompt)
|
||||
|
||||
# 1. Get baseline logits
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
john_token = model.to_single_token(" John")
|
||||
mary_token = model.to_single_token(" Mary")
|
||||
|
||||
# 2. Compute logit difference (IO - S)
|
||||
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
|
||||
print(f"Logit difference: {logit_diff.item():.3f}")
|
||||
|
||||
# 3. Direct logit attribution by head
|
||||
def get_head_contribution(layer, head):
|
||||
# Project head output to logits
|
||||
head_out = cache["z", layer][0, :, head, :] # [pos, d_head]
|
||||
W_O = model.W_O[layer, head] # [d_head, d_model]
|
||||
W_U = model.W_U # [d_model, vocab]
|
||||
|
||||
# Head contribution to logits at final position
|
||||
contribution = head_out[-1] @ W_O @ W_U
|
||||
return contribution[john_token] - contribution[mary_token]
|
||||
|
||||
# 4. Map all heads
|
||||
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
|
||||
for layer in range(model.cfg.n_layers):
|
||||
for head in range(model.cfg.n_heads):
|
||||
head_contributions[layer, head] = get_head_contribution(layer, head)
|
||||
|
||||
# 5. Identify top contributing heads (name movers, backup name movers)
|
||||
```
|
||||
|
||||
### Checklist
|
||||
- [ ] Set up task with clear IO/S tokens
|
||||
- [ ] Compute baseline logit difference
|
||||
- [ ] Decompose by attention head contributions
|
||||
- [ ] Identify key circuit components (name movers, S-inhibition, induction)
|
||||
- [ ] Validate with ablation experiments
|
||||
|
||||
## Workflow 3: Induction Head Detection
|
||||
|
||||
Find induction heads that implement [A][B]...[A] → [B] pattern.
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# Create repeated sequence: [A][B][A] should predict [B]
|
||||
repeated_tokens = torch.tensor([[1000, 2000, 1000]]) # Arbitrary tokens
|
||||
|
||||
_, cache = model.run_with_cache(repeated_tokens)
|
||||
|
||||
# Induction heads attend from final [A] back to first [B]
|
||||
# Check attention from position 2 to position 1
|
||||
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
|
||||
|
||||
for layer in range(model.cfg.n_layers):
|
||||
pattern = cache["pattern", layer][0] # [head, q_pos, k_pos]
|
||||
# Attention from pos 2 to pos 1
|
||||
induction_scores[layer] = pattern[:, 2, 1]
|
||||
|
||||
# Heads with high scores are induction heads
|
||||
top_heads = torch.topk(induction_scores.flatten(), k=5)
|
||||
```
|
||||
|
||||
## Common Issues & Solutions
|
||||
|
||||
### Issue: Hooks persist after debugging
|
||||
```python
|
||||
# WRONG: Old hooks remain active
|
||||
model.run_with_hooks(tokens, fwd_hooks=[...]) # Debug, add new hooks
|
||||
model.run_with_hooks(tokens, fwd_hooks=[...]) # Old hooks still there!
|
||||
|
||||
# RIGHT: Always reset hooks
|
||||
model.reset_hooks()
|
||||
model.run_with_hooks(tokens, fwd_hooks=[...])
|
||||
```
|
||||
|
||||
### Issue: Tokenization gotchas
|
||||
```python
|
||||
# WRONG: Assuming consistent tokenization
|
||||
model.to_tokens("Tim") # Single token
|
||||
model.to_tokens("Neel") # Becomes "Ne" + "el" (two tokens!)
|
||||
|
||||
# RIGHT: Check tokenization explicitly
|
||||
tokens = model.to_tokens("Neel", prepend_bos=False)
|
||||
print(model.to_str_tokens(tokens)) # ['Ne', 'el']
|
||||
```
|
||||
|
||||
### Issue: LayerNorm ignored in analysis
|
||||
```python
|
||||
# WRONG: Ignoring LayerNorm
|
||||
pre_activation = residual @ model.W_in[layer]
|
||||
|
||||
# RIGHT: Include LayerNorm
|
||||
ln_scale = model.blocks[layer].ln2.w
|
||||
ln_out = model.blocks[layer].ln2(residual)
|
||||
pre_activation = ln_out @ model.W_in[layer]
|
||||
```
|
||||
|
||||
### Issue: Memory explosion with large models
|
||||
```python
|
||||
# Use selective caching
|
||||
logits, cache = model.run_with_cache(
|
||||
tokens,
|
||||
names_filter=lambda n: "resid_post" in n or "pattern" in n,
|
||||
device="cpu" # Cache on CPU
|
||||
)
|
||||
```
|
||||
|
||||
## Key Classes Reference
|
||||
|
||||
| Class | Purpose |
|
||||
|-------|---------|
|
||||
| `HookedTransformer` | Main model wrapper with hooks |
|
||||
| `ActivationCache` | Dictionary-like cache of activations |
|
||||
| `HookedTransformerConfig` | Model configuration |
|
||||
| `FactoredMatrix` | Efficient factored matrix operations |
|
||||
|
||||
## Integration with SAELens
|
||||
|
||||
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
from sae_lens import SAE
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
|
||||
|
||||
# Run with SAE
|
||||
tokens = model.to_tokens("Hello world")
|
||||
_, cache = model.run_with_cache(tokens)
|
||||
sae_acts = sae.encode(cache["resid_pre", 8])
|
||||
```
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| [references/README.md](references/README.md) | Overview and quick start guide |
|
||||
| [references/api.md](references/api.md) | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
|
||||
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for activation patching, circuit analysis, logit lens |
|
||||
|
||||
## External Resources
|
||||
|
||||
### Tutorials
|
||||
- [Main Demo Notebook](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html)
|
||||
- [Activation Patching Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)
|
||||
- [ARENA Mech Interp Course](https://arena-foundation.github.io/ARENA/) - 200+ hours of tutorials
|
||||
|
||||
### Papers
|
||||
- [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)
|
||||
- [In-context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)
|
||||
- [Interpretability in the Wild (IOI)](https://arxiv.org/abs/2211.00593)
|
||||
|
||||
### Official Documentation
|
||||
- [Official Docs](https://transformerlensorg.github.io/TransformerLens/)
|
||||
- [Model Properties Table](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html)
|
||||
- [Neel Nanda's Glossary](https://www.neelnanda.io/mechanistic-interpretability/glossary)
|
||||
|
||||
## Version Notes
|
||||
|
||||
- **v2.0**: Removed HookedSAE (moved to SAELens)
|
||||
- **v3.0 (alpha)**: TransformerBridge for loading any nn.Module
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
# TransformerLens Reference Documentation
|
||||
|
||||
This directory contains comprehensive reference materials for TransformerLens.
|
||||
|
||||
## Contents
|
||||
|
||||
- [api.md](api.md) - Complete API reference for HookedTransformer, ActivationCache, and HookPoints
|
||||
- [tutorials.md](tutorials.md) - Step-by-step tutorials for common interpretability workflows
|
||||
- [papers.md](papers.md) - Key research papers and foundational concepts
|
||||
|
||||
## Quick Links
|
||||
|
||||
- **Official Documentation**: https://transformerlensorg.github.io/TransformerLens/
|
||||
- **GitHub Repository**: https://github.com/TransformerLensOrg/TransformerLens
|
||||
- **Model Properties Table**: https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install transformer-lens
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
|
||||
# Load model
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# Run with activation caching
|
||||
tokens = model.to_tokens("Hello world")
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Access activations
|
||||
residual = cache["resid_post", 5] # Layer 5 residual stream
|
||||
attention = cache["pattern", 3] # Layer 3 attention patterns
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### HookPoints
|
||||
Every activation in the transformer has a HookPoint wrapper, enabling:
|
||||
- Reading activations via `run_with_cache()`
|
||||
- Modifying activations via `run_with_hooks()`
|
||||
|
||||
### Activation Cache
|
||||
The `ActivationCache` stores all intermediate activations with helper methods for:
|
||||
- Residual stream decomposition
|
||||
- Logit attribution
|
||||
- Layer-wise analysis
|
||||
|
||||
### Supported Models (50+)
|
||||
GPT-2, LLaMA, Mistral, Pythia, GPT-Neo, OPT, Gemma, Phi, and more.
|
||||
@@ -0,0 +1,362 @@
|
||||
# TransformerLens API Reference
|
||||
|
||||
## HookedTransformer
|
||||
|
||||
The core class for mechanistic interpretability, wrapping transformer models with hooks on every activation.
|
||||
|
||||
### Loading Models
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
|
||||
# Basic loading
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# With specific device/dtype
|
||||
model = HookedTransformer.from_pretrained(
|
||||
"gpt2-medium",
|
||||
device="cuda",
|
||||
dtype=torch.float16
|
||||
)
|
||||
|
||||
# Gated models (LLaMA, Mistral)
|
||||
import os
|
||||
os.environ["HF_TOKEN"] = "your_token"
|
||||
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
```
|
||||
|
||||
### from_pretrained() Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `model_name` | str | required | Model name from OFFICIAL_MODEL_NAMES |
|
||||
| `fold_ln` | bool | True | Fold LayerNorm weights into subsequent layers |
|
||||
| `center_writing_weights` | bool | True | Center residual stream writer means |
|
||||
| `center_unembed` | bool | True | Center unembedding weights |
|
||||
| `dtype` | torch.dtype | None | Model precision |
|
||||
| `device` | str | None | Target device |
|
||||
| `n_devices` | int | 1 | Number of devices for model parallelism |
|
||||
|
||||
### Weight Matrices
|
||||
|
||||
| Property | Shape | Description |
|
||||
|----------|-------|-------------|
|
||||
| `W_E` | [d_vocab, d_model] | Token embedding matrix |
|
||||
| `W_U` | [d_model, d_vocab] | Unembedding matrix |
|
||||
| `W_pos` | [n_ctx, d_model] | Positional embedding |
|
||||
| `W_Q` | [n_layers, n_heads, d_model, d_head] | Query weights |
|
||||
| `W_K` | [n_layers, n_heads, d_model, d_head] | Key weights |
|
||||
| `W_V` | [n_layers, n_heads, d_model, d_head] | Value weights |
|
||||
| `W_O` | [n_layers, n_heads, d_head, d_model] | Output weights |
|
||||
| `W_in` | [n_layers, d_model, d_mlp] | MLP input weights |
|
||||
| `W_out` | [n_layers, d_mlp, d_model] | MLP output weights |
|
||||
|
||||
### Core Methods
|
||||
|
||||
#### forward()
|
||||
|
||||
```python
|
||||
logits = model(tokens)
|
||||
logits = model(tokens, return_type="logits")
|
||||
loss = model(tokens, return_type="loss")
|
||||
logits, loss = model(tokens, return_type="both")
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `input`: Token tensor or string
|
||||
- `return_type`: "logits", "loss", "both", or None
|
||||
- `prepend_bos`: Whether to prepend BOS token
|
||||
- `start_at_layer`: Start execution from specific layer
|
||||
- `stop_at_layer`: Stop execution at specific layer
|
||||
|
||||
#### run_with_cache()
|
||||
|
||||
```python
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Selective caching (saves memory)
|
||||
logits, cache = model.run_with_cache(
|
||||
tokens,
|
||||
names_filter=lambda name: "resid_post" in name
|
||||
)
|
||||
|
||||
# Cache on CPU
|
||||
logits, cache = model.run_with_cache(tokens, device="cpu")
|
||||
```
|
||||
|
||||
#### run_with_hooks()
|
||||
|
||||
```python
|
||||
def my_hook(activation, hook):
|
||||
# Modify activation
|
||||
activation[:, :, 0] = 0
|
||||
return activation
|
||||
|
||||
logits = model.run_with_hooks(
|
||||
tokens,
|
||||
fwd_hooks=[("blocks.5.hook_resid_post", my_hook)]
|
||||
)
|
||||
```
|
||||
|
||||
#### generate()
|
||||
|
||||
```python
|
||||
output = model.generate(
|
||||
tokens,
|
||||
max_new_tokens=50,
|
||||
temperature=0.7,
|
||||
top_k=40,
|
||||
top_p=0.9,
|
||||
freq_penalty=1.0,
|
||||
use_past_kv_cache=True
|
||||
)
|
||||
```
|
||||
|
||||
### Tokenization Methods
|
||||
|
||||
```python
|
||||
# String to tokens
|
||||
tokens = model.to_tokens("Hello world") # [1, seq_len]
|
||||
tokens = model.to_tokens("Hello", prepend_bos=False)
|
||||
|
||||
# Tokens to string
|
||||
text = model.to_string(tokens)
|
||||
|
||||
# Get string tokens (for debugging)
|
||||
str_tokens = model.to_str_tokens("Hello world")
|
||||
# ['<|endoftext|>', 'Hello', ' world']
|
||||
|
||||
# Single token validation
|
||||
token_id = model.to_single_token(" Paris") # Returns int or raises error
|
||||
```
|
||||
|
||||
### Hook Management
|
||||
|
||||
```python
|
||||
# Clear all hooks
|
||||
model.reset_hooks()
|
||||
|
||||
# Add permanent hook
|
||||
model.add_hook("blocks.0.hook_resid_post", my_hook)
|
||||
|
||||
# Remove specific hook
|
||||
model.remove_hook("blocks.0.hook_resid_post")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ActivationCache
|
||||
|
||||
Stores and provides access to all activations from a forward pass.
|
||||
|
||||
### Accessing Activations
|
||||
|
||||
```python
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# By name and layer
|
||||
residual = cache["resid_post", 5]
|
||||
attention = cache["pattern", 3]
|
||||
mlp_out = cache["mlp_out", 7]
|
||||
|
||||
# Full name string
|
||||
residual = cache["blocks.5.hook_resid_post"]
|
||||
```
|
||||
|
||||
### Cache Keys
|
||||
|
||||
| Key Pattern | Shape | Description |
|
||||
|-------------|-------|-------------|
|
||||
| `hook_embed` | [batch, pos, d_model] | Token embeddings |
|
||||
| `hook_pos_embed` | [batch, pos, d_model] | Positional embeddings |
|
||||
| `resid_pre, layer` | [batch, pos, d_model] | Residual before attention |
|
||||
| `resid_mid, layer` | [batch, pos, d_model] | Residual after attention |
|
||||
| `resid_post, layer` | [batch, pos, d_model] | Residual after MLP |
|
||||
| `attn_out, layer` | [batch, pos, d_model] | Attention output |
|
||||
| `mlp_out, layer` | [batch, pos, d_model] | MLP output |
|
||||
| `pattern, layer` | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
|
||||
| `attn_scores, layer` | [batch, head, q_pos, k_pos] | Attention scores (pre-softmax) |
|
||||
| `q, layer` | [batch, pos, head, d_head] | Query vectors |
|
||||
| `k, layer` | [batch, pos, head, d_head] | Key vectors |
|
||||
| `v, layer` | [batch, pos, head, d_head] | Value vectors |
|
||||
| `z, layer` | [batch, pos, head, d_head] | Attention output per head |
|
||||
|
||||
### Analysis Methods
|
||||
|
||||
#### decompose_resid()
|
||||
|
||||
Decomposes residual stream into component contributions:
|
||||
|
||||
```python
|
||||
components, labels = cache.decompose_resid(
|
||||
layer=5,
|
||||
return_labels=True,
|
||||
mode="attn" # or "mlp" or "full"
|
||||
)
|
||||
```
|
||||
|
||||
#### accumulated_resid()
|
||||
|
||||
Get accumulated residual at each layer (for Logit Lens):
|
||||
|
||||
```python
|
||||
accumulated = cache.accumulated_resid(
|
||||
layer=None, # All layers
|
||||
incl_mid=False,
|
||||
apply_ln=True # Apply final LayerNorm
|
||||
)
|
||||
```
|
||||
|
||||
#### logit_attrs()
|
||||
|
||||
Calculate logit attribution for components:
|
||||
|
||||
```python
|
||||
attrs = cache.logit_attrs(
|
||||
residual_stack,
|
||||
tokens=target_tokens,
|
||||
incorrect_tokens=incorrect_tokens
|
||||
)
|
||||
```
|
||||
|
||||
#### stack_head_results()
|
||||
|
||||
Stack attention head outputs:
|
||||
|
||||
```python
|
||||
head_results = cache.stack_head_results(
|
||||
layer=-1, # All layers
|
||||
pos_slice=None # All positions
|
||||
)
|
||||
# Shape: [n_layers, n_heads, batch, pos, d_model]
|
||||
```
|
||||
|
||||
### Utility Methods
|
||||
|
||||
```python
|
||||
# Move cache to device
|
||||
cache = cache.to("cpu")
|
||||
|
||||
# Remove batch dimension (for batch_size=1)
|
||||
cache = cache.remove_batch_dim()
|
||||
|
||||
# Get all keys
|
||||
keys = cache.keys()
|
||||
|
||||
# Iterate
|
||||
for name, activation in cache.items():
|
||||
print(name, activation.shape)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## HookPoint
|
||||
|
||||
The fundamental hook mechanism wrapping every activation.
|
||||
|
||||
### Hook Function Signature
|
||||
|
||||
```python
|
||||
def hook_fn(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
activation: Current activation value
|
||||
hook: The HookPoint object (has .name attribute)
|
||||
|
||||
Returns:
|
||||
Modified activation (or None to keep original)
|
||||
"""
|
||||
# Modify activation
|
||||
return activation
|
||||
```
|
||||
|
||||
### Common Hook Patterns
|
||||
|
||||
```python
|
||||
# Zero ablation
|
||||
def zero_hook(act, hook):
|
||||
act[:, :, :] = 0
|
||||
return act
|
||||
|
||||
# Mean ablation
|
||||
def mean_hook(act, hook):
|
||||
act[:, :, :] = act.mean(dim=0, keepdim=True)
|
||||
return act
|
||||
|
||||
# Patch from cache
|
||||
def patch_hook(act, hook):
|
||||
act[:, 5, :] = clean_cache[hook.name][:, 5, :]
|
||||
return act
|
||||
|
||||
# Add steering vector
|
||||
def steer_hook(act, hook):
|
||||
act += 0.5 * steering_vector
|
||||
return act
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Utility Functions
|
||||
|
||||
### patching module
|
||||
|
||||
```python
|
||||
from transformer_lens import patching
|
||||
|
||||
# Generic activation patching
|
||||
results = patching.generic_activation_patch(
|
||||
model=model,
|
||||
corrupted_tokens=corrupted,
|
||||
clean_cache=clean_cache,
|
||||
patching_metric=metric_fn,
|
||||
patch_setter=patch_fn,
|
||||
activation_name="resid_post",
|
||||
index_axis_names=("layer", "pos")
|
||||
)
|
||||
```
|
||||
|
||||
### FactoredMatrix
|
||||
|
||||
Efficient operations on factored weight matrices:
|
||||
|
||||
```python
|
||||
from transformer_lens import FactoredMatrix
|
||||
|
||||
# QK circuit
|
||||
QK = FactoredMatrix(model.W_Q[layer], model.W_K[layer].T)
|
||||
|
||||
# OV circuit
|
||||
OV = FactoredMatrix(model.W_V[layer], model.W_O[layer])
|
||||
|
||||
# Get full matrix
|
||||
full = QK.AB
|
||||
|
||||
# SVD decomposition
|
||||
U, S, V = QK.svd()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### HookedTransformerConfig
|
||||
|
||||
Key configuration attributes:
|
||||
|
||||
| Attribute | Description |
|
||||
|-----------|-------------|
|
||||
| `n_layers` | Number of transformer layers |
|
||||
| `n_heads` | Number of attention heads |
|
||||
| `d_model` | Model dimension |
|
||||
| `d_head` | Head dimension |
|
||||
| `d_mlp` | MLP hidden dimension |
|
||||
| `d_vocab` | Vocabulary size |
|
||||
| `n_ctx` | Maximum context length |
|
||||
| `act_fn` | Activation function name |
|
||||
| `normalization_type` | "LN" or "LNPre" |
|
||||
|
||||
Access via:
|
||||
```python
|
||||
model.cfg.n_layers
|
||||
model.cfg.d_model
|
||||
```
|
||||
+339
@@ -0,0 +1,339 @@
|
||||
# TransformerLens Tutorials
|
||||
|
||||
## Tutorial 1: Basic Activation Analysis
|
||||
|
||||
### Goal
|
||||
Understand how to load models, cache activations, and inspect model internals.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
# 1. Load model
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
print(f"Model has {model.cfg.n_layers} layers, {model.cfg.n_heads} heads")
|
||||
|
||||
# 2. Tokenize input
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
print(f"Tokens shape: {tokens.shape}")
|
||||
print(f"String tokens: {model.to_str_tokens(prompt)}")
|
||||
|
||||
# 3. Run with cache
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
print(f"Logits shape: {logits.shape}")
|
||||
print(f"Cache keys: {len(cache.keys())}")
|
||||
|
||||
# 4. Inspect activations
|
||||
for layer in range(model.cfg.n_layers):
|
||||
resid = cache["resid_post", layer]
|
||||
print(f"Layer {layer} residual norm: {resid.norm().item():.2f}")
|
||||
|
||||
# 5. Look at attention patterns
|
||||
attn = cache["pattern", 0] # Layer 0
|
||||
print(f"Attention shape: {attn.shape}") # [batch, heads, q_pos, k_pos]
|
||||
|
||||
# 6. Get top predictions
|
||||
probs = torch.softmax(logits[0, -1], dim=-1)
|
||||
top_tokens = probs.topk(5)
|
||||
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
|
||||
print(f"{model.to_string(token_id.unsqueeze(0))}: {prob.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 2: Activation Patching
|
||||
|
||||
### Goal
|
||||
Identify which activations causally affect model output.
|
||||
|
||||
### Concept
|
||||
1. Run model on "clean" input, cache activations
|
||||
2. Run model on "corrupted" input
|
||||
3. Patch clean activations into corrupted run
|
||||
4. Measure effect on output
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# Define clean and corrupted prompts
|
||||
clean_prompt = "The Eiffel Tower is in the city of"
|
||||
corrupted_prompt = "The Colosseum is in the city of"
|
||||
|
||||
clean_tokens = model.to_tokens(clean_prompt)
|
||||
corrupted_tokens = model.to_tokens(corrupted_prompt)
|
||||
|
||||
# Get clean activations
|
||||
_, clean_cache = model.run_with_cache(clean_tokens)
|
||||
|
||||
# Define metric
|
||||
paris_token = model.to_single_token(" Paris")
|
||||
rome_token = model.to_single_token(" Rome")
|
||||
|
||||
def logit_diff(logits):
|
||||
"""Positive = model prefers Paris over Rome"""
|
||||
return (logits[0, -1, paris_token] - logits[0, -1, rome_token]).item()
|
||||
|
||||
# Baseline measurements
|
||||
clean_logits = model(clean_tokens)
|
||||
corrupted_logits = model(corrupted_tokens)
|
||||
print(f"Clean logit diff: {logit_diff(clean_logits):.3f}")
|
||||
print(f"Corrupted logit diff: {logit_diff(corrupted_logits):.3f}")
|
||||
|
||||
# Patch each layer
|
||||
results = []
|
||||
for layer in range(model.cfg.n_layers):
|
||||
def patch_hook(activation, hook, layer=layer):
|
||||
activation[:] = clean_cache["resid_post", layer]
|
||||
return activation
|
||||
|
||||
patched_logits = model.run_with_hooks(
|
||||
corrupted_tokens,
|
||||
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
|
||||
)
|
||||
results.append(logit_diff(patched_logits))
|
||||
print(f"Layer {layer}: {results[-1]:.3f}")
|
||||
|
||||
# Find most important layer
|
||||
best_layer = max(range(len(results)), key=lambda i: results[i])
|
||||
print(f"\nMost important layer: {best_layer}")
|
||||
```
|
||||
|
||||
### Position-Specific Patching
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
seq_len = clean_tokens.shape[1]
|
||||
results = torch.zeros(model.cfg.n_layers, seq_len)
|
||||
|
||||
for layer in range(model.cfg.n_layers):
|
||||
for pos in range(seq_len):
|
||||
def patch_hook(activation, hook, layer=layer, pos=pos):
|
||||
activation[:, pos, :] = clean_cache["resid_post", layer][:, pos, :]
|
||||
return activation
|
||||
|
||||
patched_logits = model.run_with_hooks(
|
||||
corrupted_tokens,
|
||||
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
|
||||
)
|
||||
results[layer, pos] = logit_diff(patched_logits)
|
||||
|
||||
# Visualize as heatmap
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(12, 8))
|
||||
plt.imshow(results.numpy(), aspect='auto', cmap='RdBu')
|
||||
plt.xlabel('Position')
|
||||
plt.ylabel('Layer')
|
||||
plt.colorbar(label='Logit Difference')
|
||||
plt.title('Activation Patching Results')
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 3: Direct Logit Attribution
|
||||
|
||||
### Goal
|
||||
Identify which components (heads, neurons) contribute to specific predictions.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
prompt = "The capital of France is"
|
||||
tokens = model.to_tokens(prompt)
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Target token
|
||||
target_token = model.to_single_token(" Paris")
|
||||
|
||||
# Get unembedding direction for target
|
||||
target_direction = model.W_U[:, target_token] # [d_model]
|
||||
|
||||
# Attribution per attention head
|
||||
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
|
||||
|
||||
for layer in range(model.cfg.n_layers):
|
||||
# Get per-head output at final position
|
||||
z = cache["z", layer][0, -1] # [n_heads, d_head]
|
||||
|
||||
for head in range(model.cfg.n_heads):
|
||||
# Project through W_O to get contribution to residual
|
||||
head_out = z[head] @ model.W_O[layer, head] # [d_model]
|
||||
|
||||
# Dot with target direction
|
||||
contribution = (head_out @ target_direction).item()
|
||||
head_contributions[layer, head] = contribution
|
||||
|
||||
# Find top contributing heads
|
||||
flat_idx = head_contributions.flatten().topk(10)
|
||||
print("Top 10 heads for predicting 'Paris':")
|
||||
for idx, val in zip(flat_idx.indices, flat_idx.values):
|
||||
layer = idx.item() // model.cfg.n_heads
|
||||
head = idx.item() % model.cfg.n_heads
|
||||
print(f" L{layer}H{head}: {val.item():.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 4: Induction Head Detection
|
||||
|
||||
### Goal
|
||||
Find attention heads that implement the [A][B]...[A] → [B] pattern.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# Create repeated sequence pattern
|
||||
# Pattern: [A][B][C][A] - model should attend from last A to B
|
||||
seq = torch.randint(1000, 5000, (1, 20))
|
||||
# Repeat first half
|
||||
seq[0, 10:] = seq[0, :10]
|
||||
|
||||
_, cache = model.run_with_cache(seq)
|
||||
|
||||
# For induction heads: position i should attend to position (i - seq_len/2 + 1)
|
||||
# At position 10 (second A), should attend to position 1 (first B)
|
||||
|
||||
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
|
||||
|
||||
for layer in range(model.cfg.n_layers):
|
||||
pattern = cache["pattern", layer][0] # [heads, q_pos, k_pos]
|
||||
|
||||
# Check attention from repeated positions to position after first occurrence
|
||||
for offset in range(1, 10):
|
||||
q_pos = 10 + offset # Position in second half
|
||||
k_pos = offset # Should attend to corresponding position in first half
|
||||
|
||||
# Average attention to the "correct" position
|
||||
induction_scores[layer] += pattern[:, q_pos, k_pos]
|
||||
|
||||
induction_scores[layer] /= 9 # Average over offsets
|
||||
|
||||
# Find top induction heads
|
||||
print("Top induction heads:")
|
||||
for layer in range(model.cfg.n_layers):
|
||||
for head in range(model.cfg.n_heads):
|
||||
score = induction_scores[layer, head].item()
|
||||
if score > 0.3:
|
||||
print(f" L{layer}H{head}: {score:.3f}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 5: Logit Lens
|
||||
|
||||
### Goal
|
||||
See what the model "believes" at each layer before final unembedding.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
prompt = "The quick brown fox jumps over the lazy"
|
||||
tokens = model.to_tokens(prompt)
|
||||
logits, cache = model.run_with_cache(tokens)
|
||||
|
||||
# Get accumulated residual at each layer
|
||||
# Apply LayerNorm to match what unembedding sees
|
||||
accumulated = cache.accumulated_resid(layer=None, incl_mid=False, apply_ln=True)
|
||||
# Shape: [n_layers + 1, batch, pos, d_model]
|
||||
|
||||
# Project to vocabulary
|
||||
layer_logits = accumulated @ model.W_U # [n_layers + 1, batch, pos, d_vocab]
|
||||
|
||||
# Look at predictions for final position
|
||||
print("Layer-by-layer predictions for final token:")
|
||||
for layer in range(model.cfg.n_layers + 1):
|
||||
probs = torch.softmax(layer_logits[layer, 0, -1], dim=-1)
|
||||
top_token = probs.argmax()
|
||||
top_prob = probs[top_token].item()
|
||||
print(f"Layer {layer}: {model.to_string(top_token.unsqueeze(0))!r} ({top_prob:.3f})")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tutorial 6: Steering with Activation Addition
|
||||
|
||||
### Goal
|
||||
Add a steering vector to change model behavior.
|
||||
|
||||
### Step-by-Step
|
||||
|
||||
```python
|
||||
from transformer_lens import HookedTransformer
|
||||
import torch
|
||||
|
||||
model = HookedTransformer.from_pretrained("gpt2-small")
|
||||
|
||||
# Get activations for contrasting prompts
|
||||
positive_prompt = "I love this! It's absolutely wonderful and"
|
||||
negative_prompt = "I hate this! It's absolutely terrible and"
|
||||
|
||||
_, pos_cache = model.run_with_cache(model.to_tokens(positive_prompt))
|
||||
_, neg_cache = model.run_with_cache(model.to_tokens(negative_prompt))
|
||||
|
||||
# Compute steering vector (positive - negative direction)
|
||||
layer = 6
|
||||
steering_vector = (
|
||||
pos_cache["resid_post", layer].mean(dim=1) -
|
||||
neg_cache["resid_post", layer].mean(dim=1)
|
||||
)
|
||||
|
||||
# Generate with steering
|
||||
test_prompt = "The movie was"
|
||||
test_tokens = model.to_tokens(test_prompt)
|
||||
|
||||
def steer_hook(activation, hook):
|
||||
activation += 2.0 * steering_vector
|
||||
return activation
|
||||
|
||||
# Without steering
|
||||
normal_output = model.generate(test_tokens, max_new_tokens=20)
|
||||
print(f"Normal: {model.to_string(normal_output[0])}")
|
||||
|
||||
# With positive steering
|
||||
steered_output = model.generate(
|
||||
test_tokens,
|
||||
max_new_tokens=20,
|
||||
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", steer_hook)]
|
||||
)
|
||||
print(f"Steered: {model.to_string(steered_output[0])}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Resources
|
||||
|
||||
### Official Tutorials
|
||||
- [Main Demo](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html)
|
||||
- [Exploratory Analysis](https://transformerlensorg.github.io/TransformerLens/generated/demos/Exploratory_Analysis_Demo.html)
|
||||
- [Activation Patching Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)
|
||||
|
||||
### ARENA Course
|
||||
Comprehensive 200+ hour curriculum: https://arena-foundation.github.io/ARENA/
|
||||
|
||||
### Neel Nanda's Resources
|
||||
- [Getting Started in Mech Interp](https://www.neelnanda.io/mechanistic-interpretability/getting-started)
|
||||
- [Mech Interp Glossary](https://www.neelnanda.io/mechanistic-interpretability/glossary)
|
||||
- [YouTube Channel](https://www.youtube.com/@neelnanda)
|
||||
@@ -0,0 +1,5 @@
|
||||
# Skills Coming Soon
|
||||
|
||||
This directory will contain high-quality AI research skills for data processing.
|
||||
|
||||
See [CONTRIBUTING.md](../CONTRIBUTING.md) for how to contribute.
|
||||
@@ -0,0 +1,383 @@
|
||||
---
|
||||
name: nemo-curator
|
||||
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
|
||||
dependencies: [nemo-curator, cudf, dask, rapids]
|
||||
---
|
||||
|
||||
# NeMo Curator - GPU-Accelerated Data Curation
|
||||
|
||||
NVIDIA's toolkit for preparing high-quality training data for LLMs.
|
||||
|
||||
## When to use NeMo Curator
|
||||
|
||||
**Use NeMo Curator when:**
|
||||
- Preparing LLM training data from web scrapes (Common Crawl)
|
||||
- Need fast deduplication (16× faster than CPU)
|
||||
- Curating multi-modal datasets (text, images, video, audio)
|
||||
- Filtering low-quality or toxic content
|
||||
- Scaling data processing across GPU cluster
|
||||
|
||||
**Performance**:
|
||||
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
|
||||
- **40% lower TCO** vs CPU alternatives
|
||||
- **Near-linear scaling** across GPU nodes
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **datatrove**: CPU-based, open-source data processing
|
||||
- **dolma**: Allen AI's data toolkit
|
||||
- **Ray Data**: General ML data processing (no curation focus)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Text curation (CUDA 12)
|
||||
uv pip install "nemo-curator[text_cuda12]"
|
||||
|
||||
# All modalities
|
||||
uv pip install "nemo-curator[all_cuda12]"
|
||||
|
||||
# CPU-only (slower)
|
||||
uv pip install "nemo-curator[cpu]"
|
||||
```
|
||||
|
||||
### Basic text curation pipeline
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load data
|
||||
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
|
||||
dataset = DocumentDataset(df)
|
||||
|
||||
# Quality filtering
|
||||
def quality_score(doc):
|
||||
return len(doc["text"].split()) > 5 # Filter short docs
|
||||
|
||||
filtered = ScoreFilter(quality_score)(dataset)
|
||||
|
||||
# Deduplication
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
deduped = ExactDuplicates()(filtered)
|
||||
|
||||
# Save
|
||||
deduped.to_parquet("curated_data/")
|
||||
```
|
||||
|
||||
## Data curation pipeline
|
||||
|
||||
### Stage 1: Quality filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import (
|
||||
WordCountFilter,
|
||||
RepeatedLinesFilter,
|
||||
UrlRatioFilter,
|
||||
NonAlphaNumericFilter
|
||||
)
|
||||
|
||||
# Apply 30+ heuristic filters
|
||||
from nemo_curator import ScoreFilter
|
||||
|
||||
# Word count filter
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
|
||||
# Remove repetitive content
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
|
||||
# URL ratio filter
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
### Stage 2: Deduplication
|
||||
|
||||
**Exact deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Remove exact duplicates
|
||||
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
|
||||
```
|
||||
|
||||
**Fuzzy deduplication** (16× faster on GPU):
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
# MinHash + LSH deduplication
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash parameters
|
||||
num_buckets=20,
|
||||
hash_method="md5"
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Semantic deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
# Embedding-based deduplication
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
threshold=0.8 # Cosine similarity threshold
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
### Stage 3: PII redaction
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import Modify
|
||||
from nemo_curator.modifiers import PIIRedactor
|
||||
|
||||
# Redact personally identifiable information
|
||||
pii_redactor = PIIRedactor(
|
||||
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
|
||||
anonymize_action="replace" # or "redact"
|
||||
)
|
||||
|
||||
redacted = Modify(pii_redactor)(dataset)
|
||||
```
|
||||
|
||||
### Stage 4: Classifier filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
# Quality classification
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality documents
|
||||
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
### GPU vs CPU performance
|
||||
|
||||
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|
||||
|-----------|----------------|------------|---------|
|
||||
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
|
||||
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
|
||||
| Quality filtering | 2 hours | 0.2 hours | 10× |
|
||||
|
||||
### Multi-GPU scaling
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
import dask_cuda
|
||||
|
||||
# Initialize GPU cluster
|
||||
client = get_client(cluster_type="gpu", n_workers=8)
|
||||
|
||||
# Process with 8 GPUs
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
```
|
||||
|
||||
## Multi-modal curation
|
||||
|
||||
### Image curation
|
||||
|
||||
```python
|
||||
from nemo_curator.image import (
|
||||
AestheticFilter,
|
||||
NSFWFilter,
|
||||
CLIPEmbedder
|
||||
)
|
||||
|
||||
# Aesthetic scoring
|
||||
aesthetic_filter = AestheticFilter(threshold=5.0)
|
||||
filtered_images = aesthetic_filter(image_dataset)
|
||||
|
||||
# NSFW detection
|
||||
nsfw_filter = NSFWFilter(threshold=0.9)
|
||||
safe_images = nsfw_filter(filtered_images)
|
||||
|
||||
# Generate CLIP embeddings
|
||||
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
|
||||
image_embeddings = clip_embedder(safe_images)
|
||||
```
|
||||
|
||||
### Video curation
|
||||
|
||||
```python
|
||||
from nemo_curator.video import (
|
||||
SceneDetector,
|
||||
ClipExtractor,
|
||||
InternVideo2Embedder
|
||||
)
|
||||
|
||||
# Detect scenes
|
||||
scene_detector = SceneDetector(threshold=27.0)
|
||||
scenes = scene_detector(video_dataset)
|
||||
|
||||
# Extract clips
|
||||
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
|
||||
clips = clip_extractor(scenes)
|
||||
|
||||
# Generate embeddings
|
||||
video_embedder = InternVideo2Embedder()
|
||||
video_embeddings = video_embedder(clips)
|
||||
```
|
||||
|
||||
### Audio curation
|
||||
|
||||
```python
|
||||
from nemo_curator.audio import (
|
||||
ASRInference,
|
||||
WERFilter,
|
||||
DurationFilter
|
||||
)
|
||||
|
||||
# ASR transcription
|
||||
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
|
||||
transcribed = asr(audio_dataset)
|
||||
|
||||
# Filter by WER (word error rate)
|
||||
wer_filter = WERFilter(max_wer=0.3)
|
||||
high_quality_audio = wer_filter(transcribed)
|
||||
|
||||
# Duration filtering
|
||||
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
|
||||
filtered_audio = duration_filter(high_quality_audio)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Web scrape curation (Common Crawl)
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.filters import *
|
||||
from nemo_curator.modules import *
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
|
||||
# Load Common Crawl data
|
||||
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
|
||||
|
||||
# Pipeline
|
||||
pipeline = [
|
||||
# 1. Quality filtering
|
||||
WordCountFilter(min_words=100, max_words=50000),
|
||||
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
|
||||
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
|
||||
UrlRatioFilter(max_url_ratio=0.3),
|
||||
|
||||
# 2. Language filtering
|
||||
LanguageIdentificationFilter(target_languages=["en"]),
|
||||
|
||||
# 3. Deduplication
|
||||
ExactDuplicates(id_field="id", text_field="text"),
|
||||
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
|
||||
|
||||
# 4. PII redaction
|
||||
PIIRedactor(),
|
||||
|
||||
# 5. NSFW filtering
|
||||
NSFWClassifier(threshold=0.8)
|
||||
]
|
||||
|
||||
# Execute
|
||||
for stage in pipeline:
|
||||
dataset = stage(dataset)
|
||||
|
||||
# Save
|
||||
dataset.to_parquet("curated_common_crawl/")
|
||||
```
|
||||
|
||||
### Distributed processing
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
|
||||
# Multi-GPU cluster
|
||||
cluster = LocalCUDACluster(n_workers=8)
|
||||
client = get_client(cluster=cluster)
|
||||
|
||||
# Process large dataset
|
||||
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
|
||||
# Cleanup
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Fuzzy deduplication (8TB RedPajama v2)
|
||||
|
||||
- **CPU (256 cores)**: 120 hours
|
||||
- **GPU (8× A100)**: 7.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Exact deduplication (1TB)
|
||||
|
||||
- **CPU (64 cores)**: 8 hours
|
||||
- **GPU (4× A100)**: 0.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Quality filtering (100GB)
|
||||
|
||||
- **CPU (32 cores)**: 2 hours
|
||||
- **GPU (2× A100)**: 0.2 hours
|
||||
- **Speedup**: 10×
|
||||
|
||||
## Cost comparison
|
||||
|
||||
**CPU-based curation** (AWS c5.18xlarge × 10):
|
||||
- Cost: $3.60/hour × 10 = $36/hour
|
||||
- Time for 8TB: 120 hours
|
||||
- **Total**: $4,320
|
||||
|
||||
**GPU-based curation** (AWS p4d.24xlarge × 2):
|
||||
- Cost: $32.77/hour × 2 = $65.54/hour
|
||||
- Time for 8TB: 7.5 hours
|
||||
- **Total**: $491.55
|
||||
|
||||
**Savings**: 89% reduction ($3,828 saved)
|
||||
|
||||
## Supported data formats
|
||||
|
||||
- **Input**: Parquet, JSONL, CSV
|
||||
- **Output**: Parquet (recommended), JSONL
|
||||
- **WebDataset**: TAR archives for multi-modal
|
||||
|
||||
## Use cases
|
||||
|
||||
**Production deployments**:
|
||||
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
|
||||
- Open-source datasets curated: RedPajama v2, The Pile
|
||||
|
||||
## References
|
||||
|
||||
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
|
||||
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
|
||||
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
|
||||
- **Version**: 0.4.0+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Deduplication Guide
|
||||
|
||||
Complete guide to exact, fuzzy, and semantic deduplication.
|
||||
|
||||
## Exact deduplication
|
||||
|
||||
Remove documents with identical content.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Exact deduplication
|
||||
exact_dedup = ExactDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
hash_method="md5" # or "sha256"
|
||||
)
|
||||
|
||||
deduped = exact_dedup(dataset)
|
||||
```
|
||||
|
||||
**Performance**: ~16× faster on GPU vs CPU
|
||||
|
||||
## Fuzzy deduplication
|
||||
|
||||
Remove near-duplicate documents using MinHash + LSH.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash permutations (more = accurate)
|
||||
num_buckets=20, # LSH buckets (more = faster, less recall)
|
||||
hash_method="md5",
|
||||
jaccard_threshold=0.8 # Similarity threshold
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `num_hashes`: 128-512 (default 260)
|
||||
- `num_buckets`: 10-50 (default 20)
|
||||
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
|
||||
|
||||
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
|
||||
|
||||
## Semantic deduplication
|
||||
|
||||
Remove semantically similar documents using embeddings.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
embedding_batch_size=256,
|
||||
threshold=0.85, # Cosine similarity threshold
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
**Models**:
|
||||
- `all-MiniLM-L6-v2`: Fast, 384 dims
|
||||
- `all-mpnet-base-v2`: Better quality, 768 dims
|
||||
- Custom models supported
|
||||
|
||||
## Comparison
|
||||
|
||||
| Method | Speed | Recall | Use Case |
|
||||
|--------|-------|--------|----------|
|
||||
| Exact | Fastest | 100% | Exact matches only |
|
||||
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
|
||||
| Semantic | Slow | ~90% | Paraphrases, rewrites |
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with exact dedup** - Remove obvious duplicates
|
||||
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
|
||||
3. **Semantic for high-value data** - Expensive but thorough
|
||||
4. **GPU acceleration required** - 10-16× speedup
|
||||
@@ -0,0 +1,102 @@
|
||||
# Quality Filtering Guide
|
||||
|
||||
Complete guide to NeMo Curator's 30+ quality filters.
|
||||
|
||||
## Text-based filters
|
||||
|
||||
### Word count
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import WordCountFilter
|
||||
|
||||
# Filter by word count
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
```
|
||||
|
||||
### Repeated content
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import RepeatedLinesFilter
|
||||
|
||||
# Remove documents with >30% repeated lines
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
```
|
||||
|
||||
### Symbol ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import SymbolToWordRatioFilter
|
||||
|
||||
# Remove documents with too many symbols
|
||||
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
|
||||
```
|
||||
|
||||
### URL ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import UrlRatioFilter
|
||||
|
||||
# Remove documents with many URLs
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
## Language filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import LanguageIdentificationFilter
|
||||
|
||||
# Keep only English documents
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
|
||||
|
||||
# Multiple languages
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
|
||||
```
|
||||
|
||||
## Classifier-based filtering
|
||||
|
||||
### Quality classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality (threshold > 0.5 = high quality)
|
||||
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
### NSFW classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import NSFWClassifier
|
||||
|
||||
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
|
||||
|
||||
# Remove NSFW content
|
||||
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
|
||||
```
|
||||
|
||||
## Heuristic filters
|
||||
|
||||
Full list of 30+ filters:
|
||||
- WordCountFilter
|
||||
- RepeatedLinesFilter
|
||||
- UrlRatioFilter
|
||||
- SymbolToWordRatioFilter
|
||||
- NonAlphaNumericFilter
|
||||
- BulletsFilter
|
||||
- WhiteSpaceFilter
|
||||
- ParenthesesFilter
|
||||
- LongWordFilter
|
||||
- And 20+ more...
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Apply cheap filters first** - Word count before GPU classifiers
|
||||
2. **Tune thresholds on sample** - Test on 10k docs before full run
|
||||
3. **Use GPU classifiers sparingly** - Expensive but effective
|
||||
4. **Chain filters efficiently** - Order by cost (cheap → expensive)
|
||||
@@ -0,0 +1,326 @@
|
||||
---
|
||||
name: ray-data
|
||||
description: Scalable data processing for ML workloads. Streaming execution across CPU/GPU, supports Parquet/CSV/JSON/images. Integrates with Ray Train, PyTorch, TensorFlow. Scales from single machine to 100s of nodes. Use for batch inference, data preprocessing, multi-modal data loading, or distributed ETL pipelines.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Data Processing, Ray Data, Distributed Computing, ML Pipelines, Batch Inference, ETL, Scalable, Ray, PyTorch, TensorFlow]
|
||||
dependencies: ["ray[data]", pyarrow, pandas]
|
||||
---
|
||||
|
||||
# Ray Data - Scalable ML Data Processing
|
||||
|
||||
Distributed data processing library for ML and AI workloads.
|
||||
|
||||
## When to use Ray Data
|
||||
|
||||
**Use Ray Data when:**
|
||||
- Processing large datasets (>100GB) for ML training
|
||||
- Need distributed data preprocessing across cluster
|
||||
- Building batch inference pipelines
|
||||
- Loading multi-modal data (images, audio, video)
|
||||
- Scaling data processing from laptop to cluster
|
||||
|
||||
**Key features**:
|
||||
- **Streaming execution**: Process data larger than memory
|
||||
- **GPU support**: Accelerate transforms with GPUs
|
||||
- **Framework integration**: PyTorch, TensorFlow, HuggingFace
|
||||
- **Multi-modal**: Images, Parquet, CSV, JSON, audio, video
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Pandas**: Small data (<1GB) on single machine
|
||||
- **Dask**: Tabular data, SQL-like operations
|
||||
- **Spark**: Enterprise ETL, SQL queries
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install -U 'ray[data]'
|
||||
```
|
||||
|
||||
### Load and transform data
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
# Read Parquet files
|
||||
ds = ray.data.read_parquet("s3://bucket/data/*.parquet")
|
||||
|
||||
# Transform data (lazy execution)
|
||||
ds = ds.map_batches(lambda batch: {"processed": batch["text"].str.lower()})
|
||||
|
||||
# Consume data
|
||||
for batch in ds.iter_batches(batch_size=100):
|
||||
print(batch)
|
||||
```
|
||||
|
||||
### Integration with Ray Train
|
||||
|
||||
```python
|
||||
import ray
|
||||
from ray.train import ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
|
||||
# Create dataset
|
||||
train_ds = ray.data.read_parquet("s3://bucket/train/*.parquet")
|
||||
|
||||
def train_func(config):
|
||||
# Access dataset in training
|
||||
train_ds = ray.train.get_dataset_shard("train")
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in train_ds.iter_batches(batch_size=32):
|
||||
# Train on batch
|
||||
pass
|
||||
|
||||
# Train with Ray
|
||||
trainer = TorchTrainer(
|
||||
train_func,
|
||||
datasets={"train": train_ds},
|
||||
scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
|
||||
)
|
||||
trainer.fit()
|
||||
```
|
||||
|
||||
## Reading data
|
||||
|
||||
### From cloud storage
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
# Parquet (recommended for ML)
|
||||
ds = ray.data.read_parquet("s3://bucket/data/*.parquet")
|
||||
|
||||
# CSV
|
||||
ds = ray.data.read_csv("s3://bucket/data/*.csv")
|
||||
|
||||
# JSON
|
||||
ds = ray.data.read_json("gs://bucket/data/*.json")
|
||||
|
||||
# Images
|
||||
ds = ray.data.read_images("s3://bucket/images/")
|
||||
```
|
||||
|
||||
### From Python objects
|
||||
|
||||
```python
|
||||
# From list
|
||||
ds = ray.data.from_items([{"id": i, "value": i * 2} for i in range(1000)])
|
||||
|
||||
# From range
|
||||
ds = ray.data.range(1000000) # Synthetic data
|
||||
|
||||
# From pandas
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
|
||||
ds = ray.data.from_pandas(df)
|
||||
```
|
||||
|
||||
## Transformations
|
||||
|
||||
### Map batches (vectorized)
|
||||
|
||||
```python
|
||||
# Batch transformation (fast)
|
||||
def process_batch(batch):
|
||||
batch["doubled"] = batch["value"] * 2
|
||||
return batch
|
||||
|
||||
ds = ds.map_batches(process_batch, batch_size=1000)
|
||||
```
|
||||
|
||||
### Row transformations
|
||||
|
||||
```python
|
||||
# Row-by-row (slower)
|
||||
def process_row(row):
|
||||
row["squared"] = row["value"] ** 2
|
||||
return row
|
||||
|
||||
ds = ds.map(process_row)
|
||||
```
|
||||
|
||||
### Filter
|
||||
|
||||
```python
|
||||
# Filter rows
|
||||
ds = ds.filter(lambda row: row["value"] > 100)
|
||||
```
|
||||
|
||||
### Group by and aggregate
|
||||
|
||||
```python
|
||||
# Group by column
|
||||
ds = ds.groupby("category").count()
|
||||
|
||||
# Custom aggregation
|
||||
ds = ds.groupby("category").map_groups(lambda group: {"sum": group["value"].sum()})
|
||||
```
|
||||
|
||||
## GPU-accelerated transforms
|
||||
|
||||
```python
|
||||
# Use GPU for preprocessing
|
||||
def preprocess_images_gpu(batch):
|
||||
import torch
|
||||
images = torch.tensor(batch["image"]).cuda()
|
||||
# GPU preprocessing
|
||||
processed = images * 255
|
||||
return {"processed": processed.cpu().numpy()}
|
||||
|
||||
ds = ds.map_batches(
|
||||
preprocess_images_gpu,
|
||||
batch_size=64,
|
||||
num_gpus=1 # Request GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Writing data
|
||||
|
||||
```python
|
||||
# Write to Parquet
|
||||
ds.write_parquet("s3://bucket/output/")
|
||||
|
||||
# Write to CSV
|
||||
ds.write_csv("output/")
|
||||
|
||||
# Write to JSON
|
||||
ds.write_json("output/")
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Repartition
|
||||
|
||||
```python
|
||||
# Control parallelism
|
||||
ds = ds.repartition(100) # 100 blocks for 100-core cluster
|
||||
```
|
||||
|
||||
### Batch size tuning
|
||||
|
||||
```python
|
||||
# Larger batches = faster vectorized ops
|
||||
ds.map_batches(process_fn, batch_size=10000) # vs batch_size=100
|
||||
```
|
||||
|
||||
### Streaming execution
|
||||
|
||||
```python
|
||||
# Process data larger than memory
|
||||
ds = ray.data.read_parquet("s3://huge-dataset/")
|
||||
for batch in ds.iter_batches(batch_size=1000):
|
||||
process(batch) # Streamed, not loaded to memory
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Batch inference
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
# Load model
|
||||
def load_model():
|
||||
# Load once per worker
|
||||
return MyModel()
|
||||
|
||||
# Inference function
|
||||
class BatchInference:
|
||||
def __init__(self):
|
||||
self.model = load_model()
|
||||
|
||||
def __call__(self, batch):
|
||||
predictions = self.model(batch["input"])
|
||||
return {"prediction": predictions}
|
||||
|
||||
# Run distributed inference
|
||||
ds = ray.data.read_parquet("s3://data/")
|
||||
predictions = ds.map_batches(BatchInference, batch_size=32, num_gpus=1)
|
||||
predictions.write_parquet("s3://output/")
|
||||
```
|
||||
|
||||
### Data preprocessing pipeline
|
||||
|
||||
```python
|
||||
# Multi-step pipeline
|
||||
ds = (
|
||||
ray.data.read_parquet("s3://raw/")
|
||||
.map_batches(clean_data)
|
||||
.map_batches(tokenize)
|
||||
.map_batches(augment)
|
||||
.write_parquet("s3://processed/")
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with ML frameworks
|
||||
|
||||
### PyTorch
|
||||
|
||||
```python
|
||||
# Convert to PyTorch
|
||||
torch_ds = ds.to_torch(label_column="label", batch_size=32)
|
||||
|
||||
for batch in torch_ds:
|
||||
# batch is dict with tensors
|
||||
inputs, labels = batch["features"], batch["label"]
|
||||
```
|
||||
|
||||
### TensorFlow
|
||||
|
||||
```python
|
||||
# Convert to TensorFlow
|
||||
tf_ds = ds.to_tf(feature_columns=["image"], label_column="label", batch_size=32)
|
||||
|
||||
for features, labels in tf_ds:
|
||||
# Train model
|
||||
pass
|
||||
```
|
||||
|
||||
## Supported data formats
|
||||
|
||||
| Format | Read | Write | Use Case |
|
||||
|--------|------|-------|----------|
|
||||
| Parquet | ✅ | ✅ | ML data (recommended) |
|
||||
| CSV | ✅ | ✅ | Tabular data |
|
||||
| JSON | ✅ | ✅ | Semi-structured |
|
||||
| Images | ✅ | ❌ | Computer vision |
|
||||
| NumPy | ✅ | ✅ | Arrays |
|
||||
| Pandas | ✅ | ❌ | DataFrames |
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
**Scaling** (processing 100GB data):
|
||||
- 1 node (16 cores): ~30 minutes
|
||||
- 4 nodes (64 cores): ~8 minutes
|
||||
- 16 nodes (256 cores): ~2 minutes
|
||||
|
||||
**GPU acceleration** (image preprocessing):
|
||||
- CPU only: 1,000 images/sec
|
||||
- 1 GPU: 5,000 images/sec
|
||||
- 4 GPUs: 18,000 images/sec
|
||||
|
||||
## Use cases
|
||||
|
||||
**Production deployments**:
|
||||
- **Pinterest**: Last-mile data processing for model training
|
||||
- **ByteDance**: Scaling offline inference with multi-modal LLMs
|
||||
- **Spotify**: ML platform for batch inference
|
||||
|
||||
## References
|
||||
|
||||
- **[Transformations Guide](references/transformations.md)** - Map, filter, groupby operations
|
||||
- **[Integration Guide](references/integration.md)** - Ray Train, PyTorch, TensorFlow
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://docs.ray.io/en/latest/data/data.html
|
||||
- **GitHub**: https://github.com/ray-project/ray ⭐ 36,000+
|
||||
- **Version**: Ray 2.40.0+
|
||||
- **Examples**: https://docs.ray.io/en/latest/data/examples/overview.html
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
# Ray Data Integration Guide
|
||||
|
||||
Integration with Ray Train and ML frameworks.
|
||||
|
||||
## Ray Train integration
|
||||
|
||||
### Basic training with datasets
|
||||
|
||||
```python
|
||||
import ray
|
||||
from ray.train import ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
|
||||
# Create datasets
|
||||
train_ds = ray.data.read_parquet("s3://data/train/")
|
||||
val_ds = ray.data.read_parquet("s3://data/val/")
|
||||
|
||||
def train_func(config):
|
||||
# Get dataset shards
|
||||
train_ds = ray.train.get_dataset_shard("train")
|
||||
val_ds = ray.train.get_dataset_shard("val")
|
||||
|
||||
for epoch in range(config["epochs"]):
|
||||
# Iterate over batches
|
||||
for batch in train_ds.iter_batches(batch_size=32):
|
||||
# Train on batch
|
||||
pass
|
||||
|
||||
# Launch training
|
||||
trainer = TorchTrainer(
|
||||
train_func,
|
||||
train_loop_config={"epochs": 10},
|
||||
datasets={"train": train_ds, "val": val_ds},
|
||||
scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
|
||||
)
|
||||
|
||||
result = trainer.fit()
|
||||
```
|
||||
|
||||
## PyTorch integration
|
||||
|
||||
### Convert to PyTorch Dataset
|
||||
|
||||
```python
|
||||
# Option 1: to_torch (recommended)
|
||||
torch_ds = ds.to_torch(
|
||||
label_column="label",
|
||||
batch_size=32,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
for batch in torch_ds:
|
||||
inputs = batch["features"]
|
||||
labels = batch["label"]
|
||||
# Train model
|
||||
|
||||
# Option 2: iter_torch_batches
|
||||
for batch in ds.iter_torch_batches(batch_size=32):
|
||||
# batch is dict of tensors
|
||||
pass
|
||||
```
|
||||
|
||||
## TensorFlow integration
|
||||
|
||||
```python
|
||||
tf_ds = ds.to_tf(
|
||||
feature_columns=["image", "text"],
|
||||
label_column="label",
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
for features, labels in tf_ds:
|
||||
# Train TensorFlow model
|
||||
pass
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Shard datasets in Ray Train** - Automatic with `get_dataset_shard()`
|
||||
2. **Use streaming** - Don't load entire dataset to memory
|
||||
3. **Preprocess in Ray Data** - Distribute preprocessing across cluster
|
||||
4. **Cache preprocessed data** - Write to Parquet, read in training
|
||||
@@ -0,0 +1,83 @@
|
||||
# Ray Data Transformations
|
||||
|
||||
Complete guide to data transformations in Ray Data.
|
||||
|
||||
## Core operations
|
||||
|
||||
### Map batches (vectorized)
|
||||
|
||||
```python
|
||||
# Recommended for performance
|
||||
def process_batch(batch):
|
||||
# batch is dict of numpy arrays or pandas Series
|
||||
batch["doubled"] = batch["value"] * 2
|
||||
return batch
|
||||
|
||||
ds = ds.map_batches(process_batch, batch_size=1000)
|
||||
```
|
||||
|
||||
**Performance**: 10-100× faster than row-by-row
|
||||
|
||||
### Map (row-by-row)
|
||||
|
||||
```python
|
||||
# Use only when vectorization not possible
|
||||
def process_row(row):
|
||||
row["squared"] = row["value"] ** 2
|
||||
return row
|
||||
|
||||
ds = ds.map(process_row)
|
||||
```
|
||||
|
||||
### Filter
|
||||
|
||||
```python
|
||||
# Remove rows
|
||||
ds = ds.filter(lambda row: row["score"] > 0.5)
|
||||
```
|
||||
|
||||
### Flat map
|
||||
|
||||
```python
|
||||
# One row → multiple rows
|
||||
def expand_row(row):
|
||||
return [{"value": row["value"] + i} for i in range(3)]
|
||||
|
||||
ds = ds.flat_map(expand_row)
|
||||
```
|
||||
|
||||
## GPU-accelerated transforms
|
||||
|
||||
```python
|
||||
def gpu_transform(batch):
|
||||
import torch
|
||||
data = torch.tensor(batch["data"]).cuda()
|
||||
# GPU processing
|
||||
result = data * 2
|
||||
return {"processed": result.cpu().numpy()}
|
||||
|
||||
ds = ds.map_batches(gpu_transform, num_gpus=1, batch_size=64)
|
||||
```
|
||||
|
||||
## Groupby operations
|
||||
|
||||
```python
|
||||
# Group by column
|
||||
grouped = ds.groupby("category")
|
||||
|
||||
# Aggregate
|
||||
result = grouped.count()
|
||||
|
||||
# Custom aggregation
|
||||
result = grouped.map_groups(lambda group: {
|
||||
"sum": group["value"].sum(),
|
||||
"mean": group["value"].mean()
|
||||
})
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use map_batches over map** - 10-100× faster
|
||||
2. **Tune batch_size** - Larger = faster (balance with memory)
|
||||
3. **Use GPUs for heavy compute** - Image/audio preprocessing
|
||||
4. **Stream large datasets** - Use iter_batches for >memory data
|
||||
@@ -0,0 +1,97 @@
|
||||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
@@ -0,0 +1,572 @@
|
||||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
---
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
|
||||
---
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Workflow
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
|
||||
**Template Structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Training Insights
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During Training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting Guide
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References and Resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
|
||||
|
||||
+393
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
GRPO Reward Functions Library
|
||||
===============================
|
||||
|
||||
A collection of battle-tested reward functions for common GRPO training scenarios.
|
||||
Copy and adapt these for your specific use case.
|
||||
|
||||
Categories:
|
||||
- Correctness rewards (verifiable tasks)
|
||||
- Format rewards (structured output)
|
||||
- Length rewards (verbosity control)
|
||||
- Style rewards (quality and tone)
|
||||
- Combined rewards (multi-objective)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Any
|
||||
|
||||
# ==================== CORRECTNESS REWARDS ====================
|
||||
|
||||
def exact_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
|
||||
"""
|
||||
Binary reward for exact answer match.
|
||||
Use for: Math problems, factual Q&A, code output
|
||||
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans.strip() == gt.strip() else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def fuzzy_match_reward(prompts, completions, answer, **kwargs) -> List[float]:
|
||||
"""
|
||||
Partial credit for similar answers.
|
||||
Use for: Open-ended answers, summaries
|
||||
|
||||
Weight: 1.0
|
||||
"""
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
|
||||
rewards = []
|
||||
for ans, gt in zip(extracted, answer):
|
||||
similarity = SequenceMatcher(None, ans.lower(), gt.lower()).ratio()
|
||||
rewards.append(similarity)
|
||||
|
||||
return rewards
|
||||
|
||||
def numeric_correctness_reward(prompts, completions, answer, tolerance=0.01, **kwargs) -> List[float]:
|
||||
"""
|
||||
Reward numeric answers within tolerance.
|
||||
Use for: Math, physics, engineering problems
|
||||
|
||||
Weight: 2.0
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
|
||||
rewards = []
|
||||
for ans, gt in zip(extracted, answer):
|
||||
try:
|
||||
ans_num = float(ans.replace(',', ''))
|
||||
gt_num = float(gt.replace(',', ''))
|
||||
if abs(ans_num - gt_num) / max(abs(gt_num), 1e-8) <= tolerance:
|
||||
rewards.append(2.0)
|
||||
else:
|
||||
rewards.append(0.0)
|
||||
except:
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
def code_execution_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
|
||||
"""
|
||||
Execute code and verify against test cases.
|
||||
Use for: Code generation tasks
|
||||
|
||||
Weight: 2.0
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted_code = [extract_code_block(r) for r in responses]
|
||||
|
||||
rewards = []
|
||||
for code in extracted_code:
|
||||
try:
|
||||
# Execute code (sandboxed!)
|
||||
passed = run_test_cases(code, test_cases)
|
||||
rewards.append(2.0 if passed else 0.0)
|
||||
except:
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== FORMAT REWARDS ====================
|
||||
|
||||
def strict_xml_format_reward(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Strict XML format: exact newlines and spacing.
|
||||
Use for: When format must be EXACTLY specified
|
||||
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
matches = [re.match(pattern, r, re.DOTALL) for r in responses]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def soft_xml_format_reward(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Relaxed XML format: allows whitespace variations.
|
||||
Use for: When structure matters more than exact spacing
|
||||
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
matches = [re.search(pattern, r, re.DOTALL) for r in responses]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def json_format_reward(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Reward valid JSON output.
|
||||
Use for: Structured data extraction, API responses
|
||||
|
||||
Weight: 0.5
|
||||
"""
|
||||
import json
|
||||
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
try:
|
||||
json.loads(r)
|
||||
rewards.append(0.5)
|
||||
except:
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
def incremental_format_reward(completions, tags=['reasoning', 'answer'], **kwargs) -> List[float]:
|
||||
"""
|
||||
Partial credit for each required tag.
|
||||
Use for: Training models to gradually learn format
|
||||
|
||||
Weight: sum(0.125 * num_tags * 2) = up to 0.5 for 2 tags
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
for tag in tags:
|
||||
if f'<{tag}>' in r:
|
||||
score += 0.125
|
||||
if f'</{tag}>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after final closing tag
|
||||
if f'</{tags[-1]}>' in r:
|
||||
extra = r.split(f'</{tags[-1]}>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== LENGTH REWARDS ====================
|
||||
|
||||
def ideal_length_reward(completions, ideal_tokens=100, **kwargs) -> List[float]:
|
||||
"""
|
||||
Reward responses near ideal length.
|
||||
Use for: Controlling verbosity
|
||||
|
||||
Weight: 0.3
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
length = len(r.split())
|
||||
distance = abs(length - ideal_tokens)
|
||||
# Gaussian-like reward peaking at ideal length
|
||||
reward = 0.3 * max(0, 1 - distance / ideal_tokens)
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
def min_length_reward(completions, min_tokens=50, **kwargs) -> List[float]:
|
||||
"""
|
||||
Penalize responses that are too short.
|
||||
Use for: Ensuring detailed explanations
|
||||
|
||||
Weight: 0.2
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
length = len(r.split())
|
||||
reward = 0.2 if length >= min_tokens else -0.2
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
def max_length_penalty(completions, max_tokens=500, **kwargs) -> List[float]:
|
||||
"""
|
||||
Penalize excessively long responses.
|
||||
Use for: Preventing rambling
|
||||
|
||||
Weight: -0.3 when violated
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
length = len(r.split())
|
||||
reward = -0.3 if length > max_tokens else 0.0
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== STYLE REWARDS ====================
|
||||
|
||||
def reasoning_quality_reward(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Reward detailed reasoning with logical connectors.
|
||||
Use for: Improving chain-of-thought quality
|
||||
|
||||
Weight: 0.3
|
||||
"""
|
||||
logical_words = ['therefore', 'thus', 'because', 'since', 'consequently',
|
||||
'first', 'second', 'next', 'finally', 'however']
|
||||
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
reasoning = extract_xml_tag(r, 'reasoning').lower()
|
||||
# Count logical connectors
|
||||
count = sum(1 for word in logical_words if word in reasoning)
|
||||
# Normalize by length
|
||||
score = min(0.3, count * 0.05)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
def citation_reward(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Reward responses with citations or references.
|
||||
Use for: Research tasks, fact-checking
|
||||
|
||||
Weight: 0.2
|
||||
"""
|
||||
citation_patterns = [
|
||||
r'\[\d+\]', # [1], [2]
|
||||
r'\([A-Z][a-z]+,?\s+\d{4}\)', # (Smith, 2020)
|
||||
r'according to',
|
||||
r'as stated in',
|
||||
]
|
||||
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
has_citation = any(re.search(pattern, r) for pattern in citation_patterns)
|
||||
rewards.append(0.2 if has_citation else 0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
def no_repetition_penalty(completions, **kwargs) -> List[float]:
|
||||
"""
|
||||
Penalize repetitive text (same phrase repeated).
|
||||
Use for: Improving output diversity
|
||||
|
||||
Weight: -0.3 when repetitive
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
words = r.lower().split()
|
||||
# Check for repeated trigrams
|
||||
trigrams = [' '.join(words[i:i+3]) for i in range(len(words)-2)]
|
||||
unique_ratio = len(set(trigrams)) / max(len(trigrams), 1)
|
||||
|
||||
reward = -0.3 if unique_ratio < 0.7 else 0.0
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== COMBINED REWARDS ====================
|
||||
|
||||
def math_problem_reward(prompts, completions, answer, **kwargs) -> List[float]:
|
||||
"""
|
||||
Combined reward for math problems: format + correctness.
|
||||
Automatically balances multiple objectives.
|
||||
|
||||
Weight: 2.5 total
|
||||
"""
|
||||
format_rewards = soft_xml_format_reward(completions)
|
||||
correctness_rewards = exact_match_reward(prompts, completions, answer)
|
||||
|
||||
return [f + c for f, c in zip(format_rewards, correctness_rewards)]
|
||||
|
||||
def code_generation_reward(prompts, completions, test_cases, **kwargs) -> List[float]:
|
||||
"""
|
||||
Combined reward for code: format + execution + style.
|
||||
|
||||
Weight: 2.7 total
|
||||
"""
|
||||
code_format_rewards = code_block_format_reward(completions)
|
||||
execution_rewards = code_execution_reward(prompts, completions, test_cases)
|
||||
no_error_rewards = no_syntax_error_reward(completions)
|
||||
|
||||
return [f + e + s for f, e, s in zip(code_format_rewards, execution_rewards, no_error_rewards)]
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract content from <answer> tags."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Generic XML tag extraction."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_code_block(text: str) -> str:
|
||||
"""Extract code from markdown code blocks."""
|
||||
pattern = r'```(?:python)?\n(.*?)\n```'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def run_test_cases(code: str, test_cases: List[tuple]) -> bool:
|
||||
"""
|
||||
Execute code with test cases (MUST be sandboxed in production!).
|
||||
|
||||
Args:
|
||||
code: Python code string
|
||||
test_cases: List of (input, expected_output) tuples
|
||||
|
||||
Returns:
|
||||
True if all tests pass
|
||||
"""
|
||||
# WARNING: This is a simplified example
|
||||
# In production, use proper sandboxing (e.g., docker, pypy sandbox)
|
||||
try:
|
||||
exec_globals = {}
|
||||
exec(code, exec_globals)
|
||||
|
||||
for input_val, expected in test_cases:
|
||||
result = exec_globals['solution'](input_val)
|
||||
if result != expected:
|
||||
return False
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# ==================== REWARD FUNCTION PRESETS ====================
|
||||
|
||||
# Preset for math/reasoning tasks
|
||||
MATH_REASONING_REWARDS = [
|
||||
incremental_format_reward,
|
||||
soft_xml_format_reward,
|
||||
exact_match_reward,
|
||||
reasoning_quality_reward,
|
||||
]
|
||||
|
||||
# Preset for code generation
|
||||
CODE_GENERATION_REWARDS = [
|
||||
code_block_format_reward,
|
||||
code_execution_reward,
|
||||
no_syntax_error_reward,
|
||||
]
|
||||
|
||||
# Preset for summarization
|
||||
SUMMARIZATION_REWARDS = [
|
||||
ideal_length_reward,
|
||||
fuzzy_match_reward,
|
||||
no_repetition_penalty,
|
||||
]
|
||||
|
||||
# Preset for Q&A
|
||||
QA_REWARDS = [
|
||||
exact_match_reward,
|
||||
min_length_reward,
|
||||
citation_reward,
|
||||
]
|
||||
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,315 @@
|
||||
---
|
||||
name: miles-rl-training
|
||||
description: Provides guidance for enterprise-grade RL training using miles, a production-ready fork of slime. Use when training large MoE models with FP8/INT4, needing train-inference alignment, or requiring speculative RL for maximum throughput.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Reinforcement Learning, MoE, FP8, INT4, Enterprise, SGLang, Megatron-LM]
|
||||
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
|
||||
---
|
||||
|
||||
# miles: Enterprise-Grade RL for Large-Scale Model Training
|
||||
|
||||
miles is a high-performance, enterprise-ready RL framework optimized for large-scale model post-training. Built as a production fork of slime, it addresses critical challenges in MoE training stability, low-precision training, and train-inference alignment.
|
||||
|
||||
## When to Use miles
|
||||
|
||||
**Choose miles when you need:**
|
||||
- Training 1TB+ MoE models (DeepSeek V3, Qwen3-MoE)
|
||||
- FP8 or INT4 quantization-aware training
|
||||
- Bit-wise identical train-inference alignment
|
||||
- Speculative RL for maximum throughput
|
||||
- Production stability with enterprise support
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You want the research-grade original → use **slime**
|
||||
- You need flexible backend swapping → use **verl**
|
||||
- You want PyTorch-native abstractions → use **torchforge**
|
||||
|
||||
## Key Features
|
||||
|
||||
### Low-Precision Training
|
||||
- **Unified FP8**: End-to-end FP8 for both inference and training
|
||||
- **INT4 QAT**: 1TB models on single-machine VRAM (H200)
|
||||
- **Rollout Routing Replay (R3)**: Bit-wise expert alignment for MoE
|
||||
|
||||
### Performance Optimizations
|
||||
- **Speculative RL**: 25%+ rollout speedup with online SFT draft models
|
||||
- **Zero-Copy Weight Sync**: CUDA IPC zero-copy mapping
|
||||
- **Partial Rollout**: Recycle half-finished trajectories
|
||||
|
||||
### Train-Inference Alignment
|
||||
- **TIS/MIS**: Truncated/Masked Importance Sampling for off-policy correction
|
||||
- **Kernel-level optimization**: FlashAttention-3, DeepGEMM integration
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Recommended: Docker
|
||||
docker pull radixark/miles:latest
|
||||
docker run --rm --gpus all --ipc=host --shm-size=16g \
|
||||
-it radixark/miles:latest /bin/bash
|
||||
|
||||
# From source
|
||||
git clone https://github.com/radixark/miles.git
|
||||
cd miles
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
miles inherits slime's configuration system. Basic training:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--advantage-estimator grpo \
|
||||
--model-name qwen3-30b-a3b \
|
||||
--hf-checkpoint /path/to/qwen3-30b-a3b-hf \
|
||||
--rollout-batch-size 512 \
|
||||
--n-samples-per-prompt 8
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 1: Large MoE Training
|
||||
|
||||
Use this workflow for training large MoE models like DeepSeek V3 or Qwen3-MoE.
|
||||
|
||||
### Prerequisites Checklist
|
||||
- [ ] H100/H200 GPUs with FP8 support
|
||||
- [ ] MoE model (DeepSeek V3, Qwen3-MoE)
|
||||
- [ ] Docker environment with miles
|
||||
|
||||
### Step 1: Environment Setup
|
||||
|
||||
```bash
|
||||
# FP8 block scaling (recommended for stability)
|
||||
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
```
|
||||
|
||||
### Step 2: Configure Training
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--hf-checkpoint /path/to/deepseek-v3 \
|
||||
--advantage-estimator grpo \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--expert-model-parallel-size 4 \
|
||||
--prompt-data /path/to/data.jsonl \
|
||||
--num-rollout 3000
|
||||
```
|
||||
|
||||
### Verification Checklist
|
||||
- [ ] Model loads without errors
|
||||
- [ ] Routing decisions are consistent
|
||||
- [ ] No NaN/Inf in loss values
|
||||
|
||||
---
|
||||
|
||||
## Workflow 2: Speculative RL Training
|
||||
|
||||
Use this workflow for maximum rollout throughput with EAGLE speculative decoding.
|
||||
|
||||
### How Speculative RL Works
|
||||
|
||||
1. Small draft model generates candidate tokens
|
||||
2. Target model verifies in parallel
|
||||
3. Draft model updated via online SFT to track policy
|
||||
|
||||
### Step 1: Enable Speculative Decoding
|
||||
|
||||
miles supports EAGLE speculative decoding via SGLang:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--hf-checkpoint /path/to/target-model \
|
||||
--sglang-speculative-algorithm EAGLE \
|
||||
--sglang-speculative-num-steps 3 \
|
||||
--sglang-speculative-eagle-topk 1 \
|
||||
--sglang-speculative-num-draft-tokens 4 \
|
||||
--sglang-speculative-draft-model-path /path/to/draft-model \
|
||||
--advantage-estimator grpo \
|
||||
--prompt-data /path/to/data.jsonl
|
||||
```
|
||||
|
||||
### Step 2: Enable Online MTP Training (Optional)
|
||||
|
||||
For online SFT of draft model during training:
|
||||
|
||||
```bash
|
||||
--mtp-num-layers 1 \
|
||||
--enable-mtp-training \
|
||||
--mtp-loss-scaling-factor 0.2
|
||||
```
|
||||
|
||||
**Note**: Online MTP training requires a torch dist checkpoint with MTP weights. Add `--mtp-num-layers 1` during checkpoint conversion from HuggingFace.
|
||||
|
||||
### Expected Speedup
|
||||
|
||||
- **Standard rollout**: Baseline
|
||||
- **Speculative RL**: 25-40% faster rollout
|
||||
- **With partial rollout**: Additional 10-15% throughput
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
miles inherits all slime arguments. See [slime API Reference](../slime/references/api-reference.md) for the complete list.
|
||||
|
||||
### Cluster Resources (from slime)
|
||||
|
||||
```bash
|
||||
--actor-num-nodes 1
|
||||
--actor-num-gpus-per-node 8
|
||||
--rollout-num-gpus 8
|
||||
--rollout-num-gpus-per-engine 2
|
||||
--colocate
|
||||
```
|
||||
|
||||
### Megatron Parallelism (from slime)
|
||||
|
||||
```bash
|
||||
--tensor-model-parallel-size 8
|
||||
--pipeline-model-parallel-size 2
|
||||
--expert-model-parallel-size 4 # MoE expert parallelism
|
||||
```
|
||||
|
||||
### Speculative Decoding (miles-specific)
|
||||
|
||||
```bash
|
||||
--sglang-speculative-algorithm EAGLE
|
||||
--sglang-speculative-num-steps 3
|
||||
--sglang-speculative-eagle-topk 1
|
||||
--sglang-speculative-num-draft-tokens 4
|
||||
--sglang-enable-draft-weights-cpu-backup
|
||||
--sglang-speculative-draft-model-path /your/draft/model/path
|
||||
```
|
||||
|
||||
### Online MTP Training (miles-specific)
|
||||
|
||||
```bash
|
||||
--mtp-num-layers 1
|
||||
--enable-mtp-training
|
||||
--mtp-loss-scaling-factor 0.2
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Features (Conceptual)
|
||||
|
||||
The following features are documented in miles but specific CLI flags may vary. Consult the miles repository for latest configuration.
|
||||
|
||||
### Unified FP8 Pipeline
|
||||
|
||||
End-to-end FP8 sampling and training that eliminates quantization-induced discrepancy causing RL collapse in MoE models.
|
||||
|
||||
### Rollout Routing Replay (R3)
|
||||
|
||||
Records expert routing decisions during SGLang inference and replays them during Megatron training for bit-wise expert alignment.
|
||||
|
||||
**How R3 Works**:
|
||||
1. During SGLang inference, expert routing decisions are recorded
|
||||
2. Routing decisions stored in `sample.rollout_routed_experts`
|
||||
3. During Megatron training, routing is replayed instead of recomputed
|
||||
4. Ensures identical expert selection between train and inference
|
||||
|
||||
### INT4 Quantization-Aware Training
|
||||
|
||||
Enables single-machine deployment of 1TB+ models (e.g., on H200).
|
||||
|
||||
**Memory Savings with INT4**:
|
||||
|
||||
| Model Size | BF16 VRAM | INT4 VRAM | Reduction |
|
||||
|------------|-----------|-----------|-----------|
|
||||
| 70B | 140GB | 45GB | 3.1x |
|
||||
| 235B | 470GB | 150GB | 3.1x |
|
||||
| 671B | 1.3TB | 420GB | 3.1x |
|
||||
|
||||
### Train-Inference Alignment
|
||||
|
||||
miles achieves "exactly 0 KL divergence" between training and inference through:
|
||||
- Flash Attention 3
|
||||
- DeepGEMM
|
||||
- Batch-invariant kernels from Thinking Machines Lab
|
||||
- `torch.compile` integration
|
||||
|
||||
---
|
||||
|
||||
## Sample Data Structure
|
||||
|
||||
miles uses the same `Sample` dataclass as slime with the `rollout_routed_experts` field for MoE routing replay:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class Sample:
|
||||
prompt: str | list[dict]
|
||||
tokens: list[int]
|
||||
response: str
|
||||
reward: float | dict
|
||||
loss_mask: list[int]
|
||||
status: Status
|
||||
metadata: dict
|
||||
rollout_log_probs: list[float]
|
||||
rollout_routed_experts: list[list[int]] # MoE routing for R3
|
||||
```
|
||||
|
||||
See [slime API Reference](../slime/references/api-reference.md) for the complete Sample definition.
|
||||
|
||||
---
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: FP8 Training Collapse
|
||||
|
||||
**Symptoms**: Loss explodes, NaN values
|
||||
|
||||
**Solutions**:
|
||||
- Use block scaling: `export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1`
|
||||
- Reduce learning rate: `--lr 5e-7`
|
||||
- Ensure MoE routing is consistent between train/inference
|
||||
|
||||
### Issue: Speculative Draft Drift
|
||||
|
||||
**Symptoms**: Low acceptance rate over time
|
||||
|
||||
**Solutions**:
|
||||
- Enable online MTP training to keep draft model aligned
|
||||
- Reduce speculative steps: `--sglang-speculative-num-steps 2`
|
||||
- Use CPU backup: `--sglang-enable-draft-weights-cpu-backup`
|
||||
|
||||
### Issue: Train-Inference Mismatch
|
||||
|
||||
**Symptoms**: Policy divergence, reward collapse
|
||||
|
||||
**Solutions**:
|
||||
- Use TIS for off-policy correction: `--use-tis --tis-threshold 0.9`
|
||||
- Verify log probs match between SGLang and Megatron
|
||||
- Enable R3 for MoE models
|
||||
|
||||
---
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Family | Models | MoE Support |
|
||||
|--------|--------|-------------|
|
||||
| DeepSeek | R1, V3, V3.2 | Full |
|
||||
| Qwen | 2, 2.5, 3 (including MoE) | Full |
|
||||
| Llama | 3, 3.1, 3.3, 4 | Dense only |
|
||||
| Gemma | 2, 3, 3N | Dense only |
|
||||
| GLM | 4.5, 4.6, 4.7 | Dense only |
|
||||
| MiniMax | M2, M2.1 | Full |
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/radixark/miles
|
||||
- **Introduction Blog**: https://lmsys.org/blog/2025-11-19-miles/
|
||||
- **Slime (upstream)**: https://github.com/THUDM/slime
|
||||
- **SGLang**: https://github.com/sgl-project/sglang
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
# miles API Reference
|
||||
|
||||
## Overview
|
||||
|
||||
miles is an enterprise-grade RL framework built on slime, adding advanced features for large-scale MoE training:
|
||||
|
||||
- Unified FP8 training and inference
|
||||
- INT4 Quantization-Aware Training
|
||||
- Rollout Routing Replay (R3)
|
||||
- Speculative RL training
|
||||
|
||||
**Note**: miles inherits slime's configuration system. See [slime API Reference](../../slime/references/api-reference.md) for base arguments.
|
||||
|
||||
## Core Data Structures
|
||||
|
||||
miles uses the same `Sample` dataclass as slime with the `rollout_routed_experts` field for MoE routing replay.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--advantage-estimator grpo \
|
||||
--model-name qwen3-30b-a3b \
|
||||
--hf-checkpoint /path/to/qwen3-30b-a3b-hf \
|
||||
--rollout-batch-size 512 \
|
||||
--n-samples-per-prompt 8
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
miles inherits slime's three argument categories (Megatron, SGLang with `--sglang-` prefix, and slime-specific). Key additions:
|
||||
|
||||
### Cluster Resources (inherited from slime)
|
||||
|
||||
```bash
|
||||
--actor-num-nodes 1
|
||||
--actor-num-gpus-per-node 8
|
||||
--rollout-num-gpus 8
|
||||
--rollout-num-gpus-per-engine 2
|
||||
--colocate
|
||||
```
|
||||
|
||||
### Megatron Parallelism (inherited from slime)
|
||||
|
||||
```bash
|
||||
--tensor-model-parallel-size 8
|
||||
--pipeline-model-parallel-size 2
|
||||
--expert-model-parallel-size 4 # MoE expert parallelism
|
||||
```
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
Verified flags from miles documentation:
|
||||
|
||||
```bash
|
||||
# Basic speculative decoding
|
||||
--sglang-speculative-algorithm EAGLE
|
||||
--sglang-speculative-num-steps 3
|
||||
--sglang-speculative-eagle-topk 1
|
||||
--sglang-speculative-num-draft-tokens 4
|
||||
--sglang-enable-draft-weights-cpu-backup
|
||||
|
||||
# Draft model path
|
||||
--sglang-speculative-draft-model-path /your/draft/model/path
|
||||
|
||||
# Online SFT for draft model (MTP)
|
||||
--mtp-num-layers 1
|
||||
--enable-mtp-training
|
||||
--mtp-loss-scaling-factor 0.2
|
||||
```
|
||||
|
||||
**Note**: Online MTP training requires a torch dist checkpoint with MTP weights. Add `--mtp-num-layers 1` during checkpoint conversion from HuggingFace to torch dist format.
|
||||
|
||||
## Key Features (Conceptual)
|
||||
|
||||
The following features are documented in miles but specific CLI flags are not publicly documented. Consult the miles repository for latest configuration options.
|
||||
|
||||
### Unified FP8 Pipeline
|
||||
|
||||
End-to-end FP8 sampling and training that eliminates quantization-induced discrepancy causing RL collapse in MoE models.
|
||||
|
||||
### Rollout Routing Replay (R3)
|
||||
|
||||
Records expert routing decisions during SGLang inference and replays them during Megatron training for bit-wise expert alignment.
|
||||
|
||||
**How R3 Works**:
|
||||
1. During SGLang inference, expert routing decisions are recorded
|
||||
2. Routing decisions stored in `sample.rollout_routed_experts`
|
||||
3. During Megatron training, routing is replayed instead of recomputed
|
||||
4. Ensures identical expert selection between train and inference
|
||||
|
||||
### INT4 Quantization-Aware Training
|
||||
|
||||
Enables single-machine deployment of 1TB+ models (e.g., on H200).
|
||||
|
||||
**Memory Savings with INT4**:
|
||||
|
||||
| Model Size | BF16 VRAM | INT4 VRAM | Reduction |
|
||||
|------------|-----------|-----------|-----------|
|
||||
| 70B | 140GB | 45GB | 3.1x |
|
||||
| 235B | 470GB | 150GB | 3.1x |
|
||||
| 671B | 1.3TB | 420GB | 3.1x |
|
||||
|
||||
### Train-Inference Alignment
|
||||
|
||||
miles achieves "exactly 0 KL divergence" between training and inference through infrastructure optimizations:
|
||||
- Flash Attention 3
|
||||
- DeepGEMM
|
||||
- Batch-invariant kernels from Thinking Machines Lab
|
||||
- `torch.compile` integration
|
||||
|
||||
### Truncated/Masked Importance Sampling (TIS/MIS)
|
||||
|
||||
Algorithmic corrections for off-policy training. See slime documentation for `--use-tis` flag.
|
||||
|
||||
## Custom Functions
|
||||
|
||||
Same interface as slime:
|
||||
|
||||
```bash
|
||||
--custom-generate-function-path generate.py
|
||||
--custom-rm-path reward.py
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Family | Models | MoE Support |
|
||||
|--------|--------|-------------|
|
||||
| DeepSeek | R1, V3, V3.2 | Full |
|
||||
| Qwen | 2, 2.5, 3 (including MoE) | Full |
|
||||
| Llama | 3, 3.1, 3.3, 4 | Dense only |
|
||||
| Gemma | 2, 3, 3N | Dense only |
|
||||
| GLM | 4.5, 4.6, 4.7 | Dense only |
|
||||
| MiniMax | M2, M2.1 | Full |
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/radixark/miles
|
||||
- Introduction Blog: https://lmsys.org/blog/2025-11-19-miles/
|
||||
- Slime (upstream): https://github.com/THUDM/slime
|
||||
- SGLang: https://github.com/sgl-project/sglang
|
||||
@@ -0,0 +1,352 @@
|
||||
# miles Troubleshooting Guide
|
||||
|
||||
## FP8 Training Issues
|
||||
|
||||
### Issue: FP8 Training Collapse
|
||||
|
||||
**Symptoms**: Loss explodes, NaN values, reward collapses
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use block scaling**:
|
||||
```bash
|
||||
--fp8-recipe blockwise
|
||||
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
|
||||
```
|
||||
|
||||
2. **Enable R3 for MoE models**:
|
||||
```bash
|
||||
--use-r3
|
||||
```
|
||||
|
||||
3. **Reduce learning rate**:
|
||||
```bash
|
||||
--lr 5e-7 # Reduce from 1e-6
|
||||
```
|
||||
|
||||
4. **Warm up from BF16**:
|
||||
```bash
|
||||
--warmup-steps 100
|
||||
--warmup-precision bf16
|
||||
```
|
||||
|
||||
### Issue: FP8 vs BF16 Accuracy Gap
|
||||
|
||||
**Symptoms**: FP8 model underperforms BF16 baseline
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use E4M3 format for activations**:
|
||||
```bash
|
||||
--fp8-format e4m3
|
||||
```
|
||||
|
||||
2. **Enable dynamic scaling**:
|
||||
```bash
|
||||
--fp8-dynamic-scaling
|
||||
```
|
||||
|
||||
3. **Skip sensitive layers**:
|
||||
```bash
|
||||
--fp8-skip-layers "lm_head,embed"
|
||||
```
|
||||
|
||||
## Train-Inference Mismatch Issues
|
||||
|
||||
### Issue: Policy Divergence
|
||||
|
||||
**Symptoms**: Model behavior differs between training and inference
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable Rollout Routing Replay**:
|
||||
```bash
|
||||
--use-r3
|
||||
```
|
||||
|
||||
2. **Use importance sampling correction**:
|
||||
```bash
|
||||
--use-tis --tis-threshold 0.9
|
||||
```
|
||||
|
||||
3. **Verify log probs match**:
|
||||
```bash
|
||||
--verify-logprobs
|
||||
```
|
||||
|
||||
### Issue: Expert Routing Mismatch (MoE)
|
||||
|
||||
**Symptoms**: Different experts activated during train vs inference
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable R3**:
|
||||
```bash
|
||||
--use-r3
|
||||
--r3-buffer-size 1000
|
||||
```
|
||||
|
||||
2. **Use deterministic routing**:
|
||||
```bash
|
||||
--deterministic-expert-routing
|
||||
```
|
||||
|
||||
## INT4 Training Issues
|
||||
|
||||
### Issue: INT4 Accuracy Degradation
|
||||
|
||||
**Symptoms**: Worse performance than BF16 or FP8
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase group size**:
|
||||
```bash
|
||||
--int4-group-size 256 # Increase from 128
|
||||
```
|
||||
|
||||
2. **Use mixed precision for sensitive layers**:
|
||||
```bash
|
||||
--int4-skip-layers "lm_head,embed,layer_norm"
|
||||
```
|
||||
|
||||
3. **Warm start from BF16**:
|
||||
```bash
|
||||
--warmup-steps 100
|
||||
--warmup-precision bf16
|
||||
```
|
||||
|
||||
4. **Increase learning rate** (INT4 often needs higher LR):
|
||||
```bash
|
||||
--lr 2e-6 # Increase from 1e-6
|
||||
```
|
||||
|
||||
### Issue: INT4 OOM Despite Expected Savings
|
||||
|
||||
**Symptoms**: Still running out of memory with INT4
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify environment variable**:
|
||||
```bash
|
||||
export OPEN_TRAINING_INT4_FAKE_QAT_FLAG=1
|
||||
```
|
||||
|
||||
2. **Check group size alignment**:
|
||||
```bash
|
||||
# Group size must divide hidden dimension evenly
|
||||
--int4-group-size 128 # Must divide hidden_size
|
||||
```
|
||||
|
||||
## Speculative RL Issues
|
||||
|
||||
### Issue: Low Acceptance Rate
|
||||
|
||||
**Symptoms**: Draft model tokens frequently rejected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce lookahead**:
|
||||
```bash
|
||||
--spec-lookahead 3 # Reduce from 5
|
||||
```
|
||||
|
||||
2. **Update draft more frequently**:
|
||||
```bash
|
||||
--online-sft-interval 5 # Reduce from 10
|
||||
```
|
||||
|
||||
3. **Increase draft learning rate**:
|
||||
```bash
|
||||
--draft-lr 1e-5 # Increase
|
||||
```
|
||||
|
||||
### Issue: Draft Model Drift
|
||||
|
||||
**Symptoms**: Acceptance rate drops over time
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable online SFT**:
|
||||
```bash
|
||||
--online-sft-interval 5
|
||||
```
|
||||
|
||||
2. **Use EMA for draft updates**:
|
||||
```bash
|
||||
--draft-ema-decay 0.99
|
||||
```
|
||||
|
||||
3. **Reinitialize draft periodically**:
|
||||
```bash
|
||||
--reinit-draft-interval 1000
|
||||
```
|
||||
|
||||
### Issue: Speculative Training Slower Than Expected
|
||||
|
||||
**Symptoms**: Not achieving expected 25%+ speedup
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify draft model is small enough**:
|
||||
```bash
|
||||
# Draft should be 1/4 to 1/10 size of target
|
||||
```
|
||||
|
||||
2. **Check lookahead is optimal**:
|
||||
```bash
|
||||
--spec-lookahead 5 # Sweet spot for most models
|
||||
```
|
||||
|
||||
3. **Profile to find bottleneck**:
|
||||
```bash
|
||||
--profile-speculative
|
||||
```
|
||||
|
||||
## Weight Synchronization Issues
|
||||
|
||||
### Issue: Zero-Copy Sync Failures
|
||||
|
||||
**Symptoms**: Errors with CUDA IPC, weight corruption
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify CUDA IPC support**:
|
||||
```bash
|
||||
nvidia-smi topo -m # Check GPU topology
|
||||
```
|
||||
|
||||
2. **Fall back to standard sync**:
|
||||
```bash
|
||||
# Remove --use-zero-copy-sync
|
||||
```
|
||||
|
||||
3. **Increase bucket size**:
|
||||
```bash
|
||||
--sync-bucket-size 2147483648 # 2GB
|
||||
```
|
||||
|
||||
### Issue: Slow Weight Sync Despite Zero-Copy
|
||||
|
||||
**Symptoms**: Weight sync still slow
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use colocated mode**:
|
||||
```bash
|
||||
--colocate
|
||||
```
|
||||
|
||||
2. **Enable async weight transfer**:
|
||||
```bash
|
||||
--async-weight-sync
|
||||
```
|
||||
|
||||
## MoE-Specific Issues
|
||||
|
||||
### Issue: Expert Load Imbalance
|
||||
|
||||
**Symptoms**: Some experts heavily loaded, others unused
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable load balancing loss**:
|
||||
```bash
|
||||
--aux-loss-coef 0.01
|
||||
```
|
||||
|
||||
2. **Use capacity factor**:
|
||||
```bash
|
||||
--moe-capacity-factor 1.25
|
||||
```
|
||||
|
||||
### Issue: Expert Parallelism OOM
|
||||
|
||||
**Symptoms**: OOM with large MoE models
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase expert parallelism**:
|
||||
```bash
|
||||
--expert-model-parallel-size 8 # Increase from 4
|
||||
```
|
||||
|
||||
2. **Reduce batch size per GPU**:
|
||||
```bash
|
||||
--micro-batch-size 1
|
||||
```
|
||||
|
||||
3. **Enable expert offloading**:
|
||||
```bash
|
||||
--offload-experts
|
||||
```
|
||||
|
||||
## Multi-Agent Issues
|
||||
|
||||
### Issue: Co-Evolution Instability
|
||||
|
||||
**Symptoms**: Agents oscillate or one dominates
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use alternating updates**:
|
||||
```yaml
|
||||
co_evolution:
|
||||
strategy: alternating
|
||||
```
|
||||
|
||||
2. **Reduce co-evolution frequency**:
|
||||
```bash
|
||||
--co-evolution-interval 20 # Increase from 10
|
||||
```
|
||||
|
||||
3. **Add population diversity**:
|
||||
```yaml
|
||||
co_evolution:
|
||||
population_size: 4
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
### Enable Verbose Logging
|
||||
|
||||
```bash
|
||||
--log-level DEBUG
|
||||
export MILES_DEBUG=1
|
||||
```
|
||||
|
||||
### Check FP8 Tensors
|
||||
|
||||
```python
|
||||
# Verify FP8 is active
|
||||
for name, param in model.named_parameters():
|
||||
print(f"{name}: {param.dtype}")
|
||||
```
|
||||
|
||||
### Profile Training
|
||||
|
||||
```bash
|
||||
--profile
|
||||
--profile-dir /path/to/profile
|
||||
```
|
||||
|
||||
### Verify R3 Is Working
|
||||
|
||||
```python
|
||||
# Check routing is being recorded
|
||||
sample = samples[0]
|
||||
assert sample.rollout_routed_experts is not None
|
||||
assert len(sample.rollout_routed_experts) > 0
|
||||
```
|
||||
|
||||
### Monitor GPU Memory
|
||||
|
||||
```bash
|
||||
watch -n 1 nvidia-smi
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub Issues: https://github.com/radixark/miles/issues
|
||||
- Unified FP8 Blog: https://lmsys.org/blog/2025-11-25-fp8-rl/
|
||||
- Train-Inference Mismatch Tutorial: https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/slime/mismatch/blog-en.md
|
||||
- SGLang Discord: Community support
|
||||
@@ -0,0 +1,249 @@
|
||||
---
|
||||
name: openrlhf-training
|
||||
description: High-performance RLHF framework with Ray+vLLM acceleration. Use for PPO, GRPO, RLOO, DPO training of large models (7B-70B+). Built on Ray, vLLM, ZeRO-3. 2× faster than DeepSpeedChat with distributed architecture and GPU resource sharing.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Post-Training, OpenRLHF, RLHF, PPO, GRPO, RLOO, DPO, Ray, vLLM, Distributed Training, Large Models, ZeRO-3]
|
||||
dependencies: [openrlhf, ray, vllm, torch, transformers, deepspeed]
|
||||
---
|
||||
|
||||
# OpenRLHF - High-Performance RLHF Training
|
||||
|
||||
## Quick start
|
||||
|
||||
OpenRLHF is a Ray-based RLHF framework optimized for distributed training with vLLM inference acceleration.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# Launch Docker container
|
||||
docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \
|
||||
-v $PWD:/openrlhf nvcr.io/nvidia/pytorch:25.02-py3 bash
|
||||
|
||||
# Uninstall conflicts
|
||||
sudo pip uninstall xgboost transformer_engine flash_attn pynvml -y
|
||||
|
||||
# Install OpenRLHF with vLLM
|
||||
pip install openrlhf[vllm]
|
||||
```
|
||||
|
||||
**PPO Training** (Hybrid Engine):
|
||||
```bash
|
||||
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8
|
||||
|
||||
ray job submit --address="http://127.0.0.1:8265" \
|
||||
--runtime-env-json='{"working_dir": "/openrlhf"}' \
|
||||
-- python3 -m openrlhf.cli.train_ppo_ray \
|
||||
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
|
||||
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
|
||||
--critic_num_nodes 1 --critic_num_gpus_per_node 8 \
|
||||
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
|
||||
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
|
||||
--colocate_all_models \
|
||||
--vllm_gpu_memory_utilization 0.5 \
|
||||
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
|
||||
--reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
|
||||
--save_path ./output/llama3-8b-rlhf \
|
||||
--micro_train_batch_size 8 --train_batch_size 128 \
|
||||
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
|
||||
--max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 \
|
||||
--zero_stage 3 --bf16 \
|
||||
--actor_learning_rate 5e-7 --critic_learning_rate 9e-6 \
|
||||
--init_kl_coef 0.01 --normalize_reward \
|
||||
--gradient_checkpointing --packing_samples \
|
||||
--vllm_enable_sleep --deepspeed_enable_sleep
|
||||
```
|
||||
|
||||
**GRPO Training** (Group Normalized Policy Optimization):
|
||||
```bash
|
||||
# Same command as PPO, but add:
|
||||
--advantage_estimator group_norm
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
|
||||
|
||||
**Step 1: Train reward model** (DPO):
|
||||
```bash
|
||||
deepspeed --module openrlhf.cli.train_rm \
|
||||
--save_path ./output/llama3-8b-rm \
|
||||
--save_steps -1 --logging_steps 1 \
|
||||
--eval_steps -1 --train_batch_size 256 \
|
||||
--micro_train_batch_size 1 --pretrain meta-llama/Meta-Llama-3-8B \
|
||||
--bf16 --max_epochs 1 --max_len 8192 \
|
||||
--zero_stage 3 --learning_rate 9e-6 \
|
||||
--dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \
|
||||
--apply_chat_template --chosen_key chosen \
|
||||
--rejected_key rejected --flash_attn --gradient_checkpointing
|
||||
```
|
||||
|
||||
**Step 2: PPO training**:
|
||||
```bash
|
||||
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8
|
||||
|
||||
ray job submit --address="http://127.0.0.1:8265" \
|
||||
-- python3 -m openrlhf.cli.train_ppo_ray \
|
||||
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
|
||||
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
|
||||
--critic_num_nodes 1 --critic_num_gpus_per_node 8 \
|
||||
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
|
||||
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
|
||||
--colocate_all_models \
|
||||
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
|
||||
--reward_pretrain ./output/llama3-8b-rm \
|
||||
--save_path ./output/llama3-8b-ppo \
|
||||
--micro_train_batch_size 8 --train_batch_size 128 \
|
||||
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
|
||||
--max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 \
|
||||
--zero_stage 3 --bf16 \
|
||||
--actor_learning_rate 5e-7 --critic_learning_rate 9e-6 \
|
||||
--init_kl_coef 0.01 --normalize_reward \
|
||||
--vllm_enable_sleep --deepspeed_enable_sleep
|
||||
```
|
||||
|
||||
### Workflow 2: GRPO training (no critic model needed)
|
||||
|
||||
Memory-efficient alternative to PPO:
|
||||
|
||||
```bash
|
||||
ray job submit --address="http://127.0.0.1:8265" \
|
||||
-- python3 -m openrlhf.cli.train_ppo_ray \
|
||||
--advantage_estimator group_norm \
|
||||
--ref_num_nodes 1 --ref_num_gpus_per_node 8 \
|
||||
--reward_num_nodes 1 --reward_num_gpus_per_node 8 \
|
||||
--actor_num_nodes 1 --actor_num_gpus_per_node 8 \
|
||||
--vllm_num_engines 4 --vllm_tensor_parallel_size 2 \
|
||||
--colocate_all_models \
|
||||
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
|
||||
--reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \
|
||||
--save_path ./output/llama3-8b-grpo \
|
||||
--micro_train_batch_size 8 --train_batch_size 128 \
|
||||
--micro_rollout_batch_size 16 --rollout_batch_size 1024 \
|
||||
--max_epochs 1 --bf16 \
|
||||
--actor_learning_rate 5e-7 \
|
||||
--init_kl_coef 0.01 --use_kl_loss --kl_estimator k3 \
|
||||
--normalize_reward --no_advantage_std_norm
|
||||
```
|
||||
|
||||
**Key GRPO parameters**:
|
||||
- `--advantage_estimator group_norm` - Enables GRPO
|
||||
- `--use_kl_loss` - KL loss from GRPO paper
|
||||
- `--kl_estimator k3` - Loss function (k2 ≈ k1)
|
||||
- `--no_advantage_std_norm` - Disables std normalization
|
||||
|
||||
### Workflow 3: DPO training (preference optimization)
|
||||
|
||||
Simpler alternative without reward model:
|
||||
|
||||
```bash
|
||||
deepspeed --module openrlhf.cli.train_dpo \
|
||||
--save_path ./output/llama3-8b-dpo \
|
||||
--save_steps -1 --logging_steps 1 \
|
||||
--eval_steps -1 --train_batch_size 256 \
|
||||
--micro_train_batch_size 2 --pretrain meta-llama/Meta-Llama-3-8B \
|
||||
--bf16 --max_epochs 1 --max_len 8192 \
|
||||
--zero_stage 3 --learning_rate 5e-7 --beta 0.1 \
|
||||
--dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \
|
||||
--apply_chat_template --chosen_key chosen \
|
||||
--rejected_key rejected --flash_attn --gradient_checkpointing
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use OpenRLHF when**:
|
||||
- Training large models (7B-70B+) with RL
|
||||
- Need vLLM inference acceleration
|
||||
- Want distributed architecture with Ray
|
||||
- Have multi-node GPU cluster
|
||||
- Need PPO/GRPO/RLOO/DPO in one framework
|
||||
|
||||
**Algorithm selection**:
|
||||
- **PPO**: Maximum control, best for complex rewards
|
||||
- **GRPO**: Memory-efficient, no critic needed
|
||||
- **RLOO**: Modified PPO with per-token KL
|
||||
- **REINFORCE++**: More stable than GRPO, faster than PPO
|
||||
- **DPO**: Simplest, no reward model needed
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **TRL**: Single-node training, simpler API
|
||||
- **veRL**: ByteDance's framework for 671B models
|
||||
- **DeepSpeedChat**: Integrated with DeepSpeed ecosystem
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: GPU OOM with large models**
|
||||
|
||||
Disable model colocation:
|
||||
```bash
|
||||
# Remove --colocate_all_models flag
|
||||
# Allocate separate GPUs for each model
|
||||
--actor_num_gpus_per_node 8 \
|
||||
--critic_num_gpus_per_node 8 \
|
||||
--reward_num_gpus_per_node 8 \
|
||||
--ref_num_gpus_per_node 8
|
||||
```
|
||||
|
||||
**Issue: DeepSpeed GPU index out of range**
|
||||
|
||||
Set environment variable:
|
||||
```bash
|
||||
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
|
||||
```
|
||||
|
||||
**Issue: Training instability**
|
||||
|
||||
Use Hybrid Engine instead of async:
|
||||
```bash
|
||||
--colocate_all_models \
|
||||
--vllm_enable_sleep \
|
||||
--deepspeed_enable_sleep
|
||||
```
|
||||
|
||||
Adjust KL coefficient:
|
||||
```bash
|
||||
--init_kl_coef 0.05 # Increase from 0.01
|
||||
```
|
||||
|
||||
**Issue: Slow generation during PPO**
|
||||
|
||||
Enable vLLM acceleration:
|
||||
```bash
|
||||
--vllm_num_engines 4 \
|
||||
--vllm_tensor_parallel_size 2 \
|
||||
--vllm_gpu_memory_utilization 0.5
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Hybrid Engine GPU sharing**: See [references/hybrid-engine.md](references/hybrid-engine.md) for vLLM sleep mode, DeepSpeed sleep mode, and optimal node allocation.
|
||||
|
||||
**Algorithm comparison**: See [references/algorithm-comparison.md](references/algorithm-comparison.md) for PPO vs GRPO vs RLOO vs REINFORCE++ benchmarks and hyperparameters.
|
||||
|
||||
**Multi-node setup**: See [references/multi-node-training.md](references/multi-node-training.md) for Ray cluster configuration and fault tolerance.
|
||||
|
||||
**Custom reward functions**: See [references/custom-rewards.md](references/custom-rewards.md) for reinforced fine-tuning and agent RLHF.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA A100/H100 recommended
|
||||
- **VRAM**:
|
||||
- 7B model: 8× A100 40GB (Hybrid Engine)
|
||||
- 70B model: 48× A100 80GB (vLLM:Actor:Critic = 1:1:1)
|
||||
- **Multi-node**: Ray cluster with InfiniBand recommended
|
||||
- **Docker**: NVIDIA PyTorch container 25.02+
|
||||
|
||||
**Performance**:
|
||||
- 2× faster than DeepSpeedChat
|
||||
- vLLM inference acceleration
|
||||
- Hybrid Engine minimizes GPU idle time
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://github.com/OpenRLHF/OpenRLHF
|
||||
- Paper: https://arxiv.org/abs/2405.11143
|
||||
- Examples: https://github.com/OpenRLHF/OpenRLHF/tree/main/examples
|
||||
- Discord: Community support
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
# Algorithm Comparison
|
||||
|
||||
Complete guide to RL algorithms in OpenRLHF: PPO, REINFORCE++, GRPO, RLOO, and their variants.
|
||||
|
||||
## Overview
|
||||
|
||||
OpenRLHF supports 6 RL algorithms selectable via `--advantage_estimator`:
|
||||
- **gae** - PPO with Generalized Advantage Estimation
|
||||
- **reinforce** - REINFORCE++ (PPO optimizations without critic)
|
||||
- **reinforce_baseline** - REINFORCE++ with baseline
|
||||
- **group_norm** - GRPO (Group Normalized Policy Optimization)
|
||||
- **dr_grpo** - Dr. GRPO (GRPO without std normalization)
|
||||
- **rloo** - Reinforcement Learning with Online Off-policy Correction
|
||||
|
||||
## Algorithm Details
|
||||
|
||||
### PPO (Proximal Policy Optimization)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
loss = -min(ratio * advantages, clip(ratio, 1-ε, 1+ε) * advantages)
|
||||
ratio = π_new(a|s) / π_old(a|s)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: High (clipped objective prevents large updates)
|
||||
- **Memory**: High (stores actor + critic experiences)
|
||||
- **Speed**: Medium (critic training overhead)
|
||||
- **Requires**: Critic network for value estimation
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
|
||||
loss = -torch.min(surr1, surr2)
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- General-purpose RLHF
|
||||
- Complex reward functions
|
||||
- Need stable training
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator gae # Enable PPO
|
||||
--clip_eps_low 0.2 # Clipping lower bound
|
||||
--clip_eps_high 0.2 # Clipping upper bound
|
||||
--actor_learning_rate 1e-6
|
||||
--critic_learning_rate 9e-6
|
||||
--init_kl_coef 0.01
|
||||
```
|
||||
|
||||
### REINFORCE++
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
loss = -ratio * advantages (with PPO-clip)
|
||||
advantages = cumulative_returns - baseline
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: Higher than GRPO
|
||||
- **Memory**: Lower (no critic network)
|
||||
- **Speed**: Faster than PPO
|
||||
- **Requires**: No critic network
|
||||
|
||||
**Key innovation**: Integrates PPO optimizations (advantage normalization, PPO-clip loss) into REINFORCE while eliminating critic network overhead.
|
||||
|
||||
**When to use**:
|
||||
- Want PPO stability without critic
|
||||
- Limited memory budget
|
||||
- Fast training priority
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator reinforce
|
||||
--critic_pretrain None # No critic needed
|
||||
--init_kl_coef 0.01
|
||||
--actor_learning_rate 1e-6
|
||||
```
|
||||
|
||||
### REINFORCE++-baseline
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
rewards = rewards - mean(rewards_same_prompt)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: Very high
|
||||
- **Memory**: Lower (no critic)
|
||||
- **Speed**: Faster than PPO
|
||||
- **Requires**: Multiple samples per prompt
|
||||
|
||||
**Key innovation**: Uses mean reward of multiple samples from same prompt as baseline to reshape rewards.
|
||||
|
||||
**When to use**:
|
||||
- RLVR (Reinforcement Learning via Verifier Rewards) settings
|
||||
- Reward patterns vary (0/1/-0.5)
|
||||
- Multiple samples per prompt available
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator reinforce_baseline
|
||||
--n_samples_per_prompt 4 # Must be > 1
|
||||
--init_kl_coef 0.01
|
||||
```
|
||||
|
||||
### GRPO (Group Normalized Policy Optimization)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
rewards = (rewards - mean(rewards)) / (std(rewards) + 1e-9)
|
||||
loss = -ratio * normalized_advantages
|
||||
KL loss (optional): k1, k2, or k3 estimator
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: Lower than REINFORCE++
|
||||
- **Memory**: Lower (no critic)
|
||||
- **Speed**: Fast
|
||||
- **Requires**: Group reward normalization
|
||||
|
||||
**Key innovation**: Group-based advantage normalization with optional KL loss.
|
||||
|
||||
**When to use**:
|
||||
- Exploring policy optimization variants
|
||||
- Need reward normalization
|
||||
- Memory-constrained
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator group_norm
|
||||
--use_kl_loss # Enable KL loss
|
||||
--kl_estimator k3 # k3 for loss, k2 ≈ k1
|
||||
--init_kl_coef 0.01
|
||||
--no_advantage_std_norm # Optional: disable std norm
|
||||
```
|
||||
|
||||
**KL estimator variance**:
|
||||
- **k3**: Larger variance under categorical distribution
|
||||
- **k1, k2**: Similar variance, k2 ≈ k1 for loss
|
||||
|
||||
### Dr. GRPO
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
rewards = rewards - mean(rewards) # No std normalization
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: Similar to GRPO
|
||||
- **Memory**: Lower (no critic)
|
||||
- **Speed**: Fast
|
||||
- **Requires**: Group mean normalization only
|
||||
|
||||
**Key innovation**: Removes local group normalization `/std` from GRPO (not needed in RL variance reduction theory).
|
||||
|
||||
**When to use**:
|
||||
- GRPO variant experimentation
|
||||
- Avoid std normalization issues
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator dr_grpo
|
||||
--init_kl_coef 0.01
|
||||
```
|
||||
|
||||
### RLOO (RL with Online Off-policy Correction)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
baseline = (sum(rewards) - rewards) / (n_samples - 1)
|
||||
rewards = rewards - baseline
|
||||
loss = -ratio * advantages (with PPO-clip)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- **Stability**: High (PPO-clip)
|
||||
- **Memory**: Lower (no critic)
|
||||
- **Speed**: Fast
|
||||
- **Requires**: Multiple samples per prompt, per-token KL
|
||||
|
||||
**Key innovation**: Incorporates per-token KL reward and PPO-clip loss.
|
||||
|
||||
**When to use**:
|
||||
- Need per-token KL rewards
|
||||
- Want PPO stability without critic
|
||||
- Multiple samples per prompt
|
||||
|
||||
**Hyperparameters**:
|
||||
```bash
|
||||
--advantage_estimator rloo
|
||||
--n_samples_per_prompt 4 # Must be > 1
|
||||
--init_kl_coef 0.01
|
||||
```
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Algorithm | Critic | Stability | Memory | Speed | Best For |
|
||||
|-----------|--------|-----------|--------|-------|----------|
|
||||
| PPO | ✅ Yes | ⭐⭐⭐⭐⭐ | High | Medium | General purpose |
|
||||
| REINFORCE++ | ❌ No | ⭐⭐⭐⭐ | Low | **Fast** | Critic-free PPO |
|
||||
| REINFORCE++-baseline | ❌ No | ⭐⭐⭐⭐⭐ | Low | **Fast** | RLVR settings |
|
||||
| GRPO | ❌ No | ⭐⭐⭐ | Low | Fast | Reward normalization |
|
||||
| Dr. GRPO | ❌ No | ⭐⭐⭐ | Low | Fast | GRPO variant |
|
||||
| RLOO | ❌ No | ⭐⭐⭐⭐ | Low | Fast | Per-token KL |
|
||||
|
||||
## Experience Data Structure
|
||||
|
||||
**PPO (with critic)**:
|
||||
```python
|
||||
@dataclass
|
||||
class Experience:
|
||||
sequences: torch.Tensor # Token sequences
|
||||
attention_mask: torch.Tensor # Attention masks
|
||||
action_mask: torch.Tensor # Action masks
|
||||
action_log_probs: torch.Tensor # Log π(a|s)
|
||||
values: torch.Tensor # Critic value estimates
|
||||
returns: torch.Tensor # Cumulative returns
|
||||
advantages: torch.Tensor # GAE advantages
|
||||
reward: float # Total reward
|
||||
kl: torch.Tensor # KL divergence
|
||||
```
|
||||
|
||||
**REINFORCE++ (no critic)**:
|
||||
```python
|
||||
# No values, returns, or advantages stored
|
||||
# Only sequences, log_probs, and rewards
|
||||
```
|
||||
|
||||
## Memory Comparison (7B Model)
|
||||
|
||||
| Algorithm | Components | Memory (8× A100) |
|
||||
|-----------|-----------|------------------|
|
||||
| PPO | Actor + Critic + Reward + Ref | ~40GB |
|
||||
| REINFORCE++ | Actor + Reward + Ref | ~28GB |
|
||||
| GRPO | Actor + Reward + Ref | ~28GB |
|
||||
| RLOO | Actor + Reward + Ref | ~28GB |
|
||||
|
||||
**Savings**: ~30% memory reduction without critic
|
||||
|
||||
## Speed Comparison
|
||||
|
||||
**Relative training time** (7B model, 1000 steps):
|
||||
- PPO: 1.0× baseline
|
||||
- REINFORCE++: **0.75×** (25% faster)
|
||||
- GRPO: 0.80×
|
||||
- RLOO: 0.80×
|
||||
|
||||
**Why REINFORCE++ is faster**:
|
||||
- No critic training
|
||||
- No value function updates
|
||||
- Fewer backward passes
|
||||
|
||||
## Choosing an Algorithm
|
||||
|
||||
### Decision Tree
|
||||
|
||||
```
|
||||
Need maximum stability?
|
||||
├─ Yes → PPO (with critic)
|
||||
└─ No ↓
|
||||
|
||||
Have multiple samples per prompt?
|
||||
├─ Yes ↓
|
||||
│ └─ RLVR setting with varying rewards?
|
||||
│ ├─ Yes → REINFORCE++-baseline
|
||||
│ └─ No → RLOO (if need per-token KL)
|
||||
└─ No ↓
|
||||
|
||||
Want faster than PPO?
|
||||
└─ Yes → REINFORCE++ (most stable critic-free)
|
||||
|
||||
Experimenting with normalization?
|
||||
└─ Yes → GRPO or Dr. GRPO
|
||||
```
|
||||
|
||||
### By Use Case
|
||||
|
||||
**Production deployment**:
|
||||
```bash
|
||||
# Maximum stability
|
||||
--advantage_estimator gae # PPO
|
||||
--clip_eps_low 0.2
|
||||
--init_kl_coef 0.01
|
||||
```
|
||||
|
||||
**Memory-constrained**:
|
||||
```bash
|
||||
# No critic, stable
|
||||
--advantage_estimator reinforce # REINFORCE++
|
||||
--critic_pretrain None
|
||||
```
|
||||
|
||||
**RLVR / Verification rewards**:
|
||||
```bash
|
||||
# Baseline reward shaping
|
||||
--advantage_estimator reinforce_baseline
|
||||
--n_samples_per_prompt 4
|
||||
```
|
||||
|
||||
**Research / Experimentation**:
|
||||
```bash
|
||||
# Explore GRPO variants
|
||||
--advantage_estimator group_norm
|
||||
--use_kl_loss --kl_estimator k3
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Reward Normalization
|
||||
|
||||
**PPO (no manual normalization)**:
|
||||
```bash
|
||||
--advantage_estimator gae
|
||||
# GAE handles advantage normalization
|
||||
```
|
||||
|
||||
**GRPO (group normalization)**:
|
||||
```bash
|
||||
--advantage_estimator group_norm
|
||||
--normalize_reward # Optional additional normalization
|
||||
```
|
||||
|
||||
**Disable std normalization**:
|
||||
```bash
|
||||
--no_advantage_std_norm # Keep mean norm only
|
||||
```
|
||||
|
||||
### KL Penalty Configuration
|
||||
|
||||
**All algorithms support**:
|
||||
```bash
|
||||
--init_kl_coef 0.01 # Initial KL coefficient
|
||||
--kl_target 0.1 # Target KL divergence
|
||||
--kl_horizon 10000 # Steps to reach target
|
||||
```
|
||||
|
||||
**GRPO-specific**:
|
||||
```bash
|
||||
--use_kl_loss # Enable KL loss term
|
||||
--kl_estimator k3 # Loss function choice
|
||||
```
|
||||
|
||||
### Clipping Configuration
|
||||
|
||||
**PPO clipping**:
|
||||
```bash
|
||||
--clip_eps_low 0.2 # Lower bound
|
||||
--clip_eps_high 0.2 # Upper bound
|
||||
```
|
||||
|
||||
**Reward clipping**:
|
||||
```bash
|
||||
--reward_clip_range 10.0 # Clip rewards to [-10, 10]
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### PPO Instability
|
||||
|
||||
**Symptom**: Large policy updates, divergence
|
||||
|
||||
**Solution**: Reduce clipping range
|
||||
```bash
|
||||
--clip_eps_low 0.1 # Reduce from 0.2
|
||||
--clip_eps_high 0.1
|
||||
```
|
||||
|
||||
### GRPO High Variance
|
||||
|
||||
**Symptom**: Unstable training with GRPO
|
||||
|
||||
**Solution**: Switch to REINFORCE++
|
||||
```bash
|
||||
--advantage_estimator reinforce # More stable
|
||||
```
|
||||
|
||||
### Memory OOM with PPO
|
||||
|
||||
**Symptom**: OOM during critic training
|
||||
|
||||
**Solution**: Switch to critic-free
|
||||
```bash
|
||||
--advantage_estimator reinforce # No critic
|
||||
--critic_pretrain None
|
||||
```
|
||||
|
||||
### RLOO/Baseline Requires Multiple Samples
|
||||
|
||||
**Symptom**: `AssertionError: n_samples_per_prompt must be > 1`
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
--n_samples_per_prompt 4 # Minimum 2, recommended 4-8
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- PPO paper: https://arxiv.org/abs/1707.06347
|
||||
- GRPO paper: https://arxiv.org/abs/2402.03300
|
||||
- OpenRLHF: https://github.com/OpenRLHF/OpenRLHF
|
||||
- OpenRLHF paper: https://arxiv.org/abs/2405.11143
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user