Files
ara/orchestra-skills/01-model-architecture/mamba/references/training-guide.md
T

8.8 KiB

Mamba Training Guide

Training from Scratch

Setup Environment

# Install dependencies
# Version specifiers are quoted: an unquoted ">" is treated by the shell as
# output redirection, which silently drops the constraint and creates a file.
pip install "torch>=1.12.0" --extra-index-url https://download.pytorch.org/whl/cu116
pip install packaging ninja
pip install "causal-conv1d>=1.1.0"
pip install mamba-ssm

# Verify CUDA
python -c "import torch; print(torch.cuda.is_available())"

Basic Training Loop

import torch
import torch.nn.functional as F  # needed for F.cross_entropy below
from mamba_ssm import Mamba
from torch.utils.data import DataLoader

# `dataloader` (yielding (inputs, targets) pairs) and `vocab_size` must be
# defined by the surrounding script.

# Model setup
# NOTE(review): a bare Mamba block appears to return features shaped like its
# input (batch, length, d_model). The loss below assumes the model emits
# vocab_size logits per position (e.g. embedding + LM head around the block)
# -- confirm before using as-is.
model = Mamba(
    d_model=512,
    d_state=16,
    d_conv=4,
    expand=2
).cuda()

# Optimizer (same as GPT)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1
)

# Training loop
for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    # Forward: flatten to (N, vocab_size) logits and (N,) targets for the loss
    logits = model(inputs)
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    # Backward: clear old grads, backprop, clip the global norm to 1.0, step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

Distributed Training

Single-Node Multi-GPU (DDP)

import os  # needed for os.environ["LOCAL_RANK"]

import torch
import torch.distributed as dist
from mamba_ssm import Mamba
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group
dist.init_process_group("nccl")
# LOCAL_RANK is read from the environment (the launch command below uses torchrun)
local_rank = int(os.environ["LOCAL_RANK"])
# Select this process's GPU before any .cuda() call so the model lands on it
torch.cuda.set_device(local_rank)

# Wrap model
model = Mamba(...).cuda()
model = DDP(model, device_ids=[local_rank])

# Train
# `dataloader` and `compute_loss` must be provided by the surrounding script.
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
for batch in dataloader:
    loss = compute_loss(model, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Launch:

torchrun --nproc_per_node=8 train.py

Multi-Node Training

# Node 0 (master)
# Expects MASTER_ADDR to be set in the environment. The same address, port,
# --nnodes and --nproc_per_node values are passed on every node.
torchrun --nproc_per_node=8 \
  --nnodes=4 --node_rank=0 \
  --master_addr=$MASTER_ADDR --master_port=29500 \
  train.py

# Node 1-3 (workers)
# Expects NODE_RANK to be set to this worker's rank (1, 2 or 3).
torchrun --nproc_per_node=8 \
  --nnodes=4 --node_rank=$NODE_RANK \
  --master_addr=$MASTER_ADDR --master_port=29500 \
  train.py

Mixed Precision Training

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

# `model`, `optimizer`, `dataloader` and `vocab_size` come from the training setup.

# BF16 (no scaler needed on A100/H100)
for batch in dataloader:
    # Unpack the batch and move it to the GPU before the forward pass
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    # Only the forward pass and loss run under autocast
    with autocast(dtype=torch.bfloat16):
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

FP16 (with gradient scaling)

# `autocast` and `GradScaler` come from torch.cuda.amp; `model`, `optimizer`,
# `dataloader` and `vocab_size` come from the training setup.
scaler = GradScaler()

for batch in dataloader:
    # Unpack the batch and move it to the GPU before the forward pass
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast(dtype=torch.float16):
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    optimizer.zero_grad()
    # Scale the loss before backward so small FP16 gradients do not underflow
    scaler.scale(loss).backward()
    # Unscale first so clipping sees the true gradient magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

Hyperparameter Recommendations

Learning Rate Schedule

# Cosine decay with warmup (GPT-3 style)
def get_lr(it, warmup_iters=2000, lr_decay_iters=600000, max_lr=6e-4, min_lr=6e-5):
    """Return the learning rate for iteration ``it``.

    The schedule has three phases:
      1. Linear warmup from 0 to ``max_lr`` over the first ``warmup_iters`` steps.
      2. Cosine decay from ``max_lr`` down to ``min_lr`` until ``lr_decay_iters``.
      3. Constant ``min_lr`` afterwards.

    Args:
        it: Zero-based training iteration.
        warmup_iters: Number of linear warmup iterations.
        lr_decay_iters: Iteration at which the cosine decay reaches ``min_lr``.
        max_lr: Peak learning rate reached at the end of warmup.
        min_lr: Floor learning rate after decay.

    Returns:
        The learning rate to use at iteration ``it``.
    """
    import math  # local import keeps the snippet self-contained

    # Warmup
    if it < warmup_iters:
        return max_lr * it / warmup_iters

    # Decay finished. Using >= returns the same value as the cosine branch at
    # it == lr_decay_iters (coefficient 0) and avoids a division by zero when
    # lr_decay_iters == warmup_iters.
    if it >= lr_decay_iters:
        return min_lr

    # Cosine: coeff goes from 1 at the end of warmup to 0 at lr_decay_iters
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

# Apply in training loop
for it, batch in enumerate(dataloader):
    # Look up the scheduled learning rate for this iteration index
    lr = get_lr(it)
    # Write it into every parameter group of the optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Batch Size Recommendations

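| Model Size | Per-GPU Batch | Gradient Accum | Effective Batch | GPUs |
|------------|---------------|----------------|-----------------|------|
| 130M       | 32            | 4              | 1024            | 8    |
| 370M       | 16            | 8              | 1024            | 8    |
| 790M       | 8             | 8              | 512             | 8    |
| 1.4B       | 4             | 16             | 512             | 8    |
| 2.8B       | 2             | 16             | 256             | 8    |

Effective batch = per-GPU batch × gradient accumulation steps × number of GPUs.
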
# Gradient accumulation
accumulation_steps = 8
# Start from clean gradients before the first micro-batch
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
    # Divide by accumulation_steps so the summed gradients average the micro-batch losses
    loss = compute_loss(model, batch) / accumulation_steps
    loss.backward()

    # Step once every accumulation_steps micro-batches
    # NOTE(review): if the number of batches is not a multiple of
    # accumulation_steps, the trailing micro-batches never trigger a step here.
    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

Optimizer Configuration

# AdamW (recommended)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,           # Peak learning rate
    betas=(0.9, 0.95), # Standard for LLMs
    eps=1e-8,
    weight_decay=0.1   # Important for generalization
)

# Weight decay exemptions (optional)
# Lists (not sets) keep parameters in named_parameters() order. Optimizer
# state is matched to parameters by position within each group, so the order
# must be identical across runs and ranks for checkpoints to load correctly.
decay = []
no_decay = []
for name, param in model.named_parameters():
    # mamba-ssm tags some SSM parameters (e.g. A_log, D) with
    # `_no_weight_decay = True`; honor that tag alongside the name-based rule.
    tagged_no_decay = getattr(param, '_no_weight_decay', False)
    if 'norm' in name or 'bias' in name or tagged_no_decay:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 0.1},
    {'params': no_decay, 'weight_decay': 0.0}
], lr=6e-4, betas=(0.9, 0.95))

Memory Optimization

Gradient Checkpointing

from torch.utils.checkpoint import checkpoint

class MambaBlock(nn.Module):
    """Residual block: RMSNorm followed by a Mamba layer, added back to the input.

    When ``use_checkpoint`` is set and the module is in training mode, the
    normalized Mamba branch is run through ``torch.utils.checkpoint.checkpoint``
    (non-reentrant variant) instead of being called directly.
    """

    def __init__(self, d_model, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm = RMSNorm(d_model)
        self.mamba = Mamba(d_model)

    def _forward(self, x):
        # Branch without the residual: normalize, then apply the Mamba layer
        normed = self.norm(x)
        return self.mamba(normed)

    def forward(self, x):
        # Checkpointing only applies while training
        recompute = self.use_checkpoint and self.training
        if recompute:
            branch = checkpoint(self._forward, x, use_reentrant=False)
        else:
            branch = self._forward(x)
        return x + branch

# Enable for training
model = MambaLM(use_checkpoint=True)

Memory savings: ~30-40% with minimal speed impact

Flash Attention Integration

Mamba has no attention layers, so FlashAttention itself is not used. Its selective-scan CUDA kernels already apply the same kind of hardware-aware optimizations:

  • Fused operations in single kernel
  • Recomputation in backward pass
  • No intermediate activation storage

Long Context Training

Sequence Length Progression

# Start short, increase gradually
# Each stage pairs a sequence length with the max_iters passed to train() for that stage
training_stages = [
    {'seq_len': 512,  'iters': 50000},
    {'seq_len': 1024, 'iters': 100000},
    {'seq_len': 2048, 'iters': 150000},
    {'seq_len': 4096, 'iters': 200000},
]

for stage in training_stages:
    # Build a fresh dataloader for this stage's sequence length; the same model is passed to every stage
    dataloader = create_dataloader(seq_len=stage['seq_len'])
    train(model, dataloader, max_iters=stage['iters'])

Memory Requirements (Batch Size 1)

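| Sequence Length | 130M Model | 370M Model | 1.4B Model |
|-----------------|------------|------------|------------|
| 2K              | 4 GB       | 8 GB       | 24 GB      |
| 4K              | 5 GB       | 10 GB      | 32 GB      |
| 8K              | 6 GB       | 14 GB      | 48 GB      |
| 16K             | 8 GB       | 20 GB      | 64 GB      |
| 32K             | 12 GB      | 32 GB      | 96 GB      |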

Mamba advantage: memory grows linearly with sequence length, whereas standard self-attention in Transformers grows quadratically.

Common Training Issues

Issue: OOM during training

Solution 1: Reduce batch size

per_gpu_batch = 8  # Reduce from 16
gradient_accumulation = 8  # Increase from 4

Solution 2: Enable gradient checkpointing

model = MambaLM(use_checkpoint=True)

Solution 3: Use smaller sequence length

seq_len = 1024  # Reduce from 2048

Issue: Training unstable (loss spikes)

Solution 1: Check gradient norm

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
print(f"Grad norm: {grad_norm}")  # Should be < 10

Solution 2: Lower learning rate

max_lr = 3e-4  # Reduce from 6e-4

Solution 3: Check Δ initialization

# Ensure dt_min, dt_max are reasonable
model = Mamba(
    d_model=512,
    dt_min=0.001,  # Not too small
    dt_max=0.1     # Not too large
)

Issue: Slow training speed

Solution 1: Verify CUDA kernels installed

import mamba_ssm
print(mamba_ssm.__version__)

# A version string alone does not prove the compiled kernels are present;
# importing the compiled extension modules does.
try:
    import selective_scan_cuda  # compiled extension built by mamba-ssm
    import causal_conv1d_cuda   # compiled extension built by causal-conv1d
    print("CUDA kernels available")
except ImportError as err:
    print(f"CUDA kernels missing: {err}")

Solution 2: Use BF16 on A100/H100

with autocast(dtype=torch.bfloat16):  # Faster than FP16
    loss = model(inputs)

Solution 3: Increase batch size if possible

per_gpu_batch = 16  # Increase from 8 (better GPU utilization)

Checkpointing

Save/Load Model

# Save
# Bundle everything needed to resume: weights, optimizer state, step counter, config
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'iter': iteration,
    'config': model_config
}
torch.save(checkpoint, f'checkpoint_{iteration}.pt')

# Load
# map_location='cpu' stops tensors being restored onto the GPU they were saved
# from (which would put every rank's copy on one device in multi-GPU runs);
# load_state_dict then copies values into the existing parameters.
checkpoint = torch.load('checkpoint_100000.pt', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
iteration = checkpoint['iter']

Best Practices

import glob
import os
import re

# Save every N iterations
if iteration % save_interval == 0:
    save_checkpoint(model, optimizer, iteration)


def _checkpoint_iter(path):
    """Return the iteration number in a checkpoint_<iter>.pt path, or None if absent."""
    match = re.search(r'checkpoint_(\d+)\.pt$', path)
    return int(match.group(1)) if match else None


# Keep only last K checkpoints
# Sort by the numeric iteration: a plain string sort places
# checkpoint_100000.pt before checkpoint_9000.pt and would delete the newest
# files. Files without a numeric iteration are left untouched.
checkpoints = sorted(
    (p for p in glob.glob('checkpoint_*.pt') if _checkpoint_iter(p) is not None),
    key=_checkpoint_iter,
)
if len(checkpoints) > keep_last:
    for ckpt in checkpoints[:-keep_last]:
        os.remove(ckpt)

Resources