RWKV-7: Latest Improvements (March 2025)
Overview
RWKV-7 is the latest version of the RWKV architecture, released in March 2025. It introduces multimodal capabilities and improved scaling to 14B+ parameters.
Paper: https://arxiv.org/abs/2503.14456 (March 2025)
Key Improvements Over RWKV-6
1. Enhanced Numerical Stability
Problem in RWKV-6:
# Exponential operations could overflow for large models
att_aa = exp(w) * att_aa + k * v # Overflow risk!
RWKV-7 Solution:
# Log-space accumulation with safe exponentiation: log(x + y) = logaddexp(log x, log y)
log_att_aa = logaddexp(log_w + log_att_aa, log(k * v))  # requires k * v > 0
att_aa = exp(log_att_aa)
Result: Stable training up to 14B parameters (RWKV-6 struggled beyond 7B)
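To make the overflow concrete, here is a minimal NumPy sketch of a single-channel version of the recurrence above. It assumes `k * v > 0` so that its log exists. It illustrates the log-space trick only and is not RWKV's actual kernel.

```python
import numpy as np

def naive_accumulate_fp16(log_w, log_kv):
    """Run a_t = exp(log_w) * a_{t-1} + exp(log_kv[t]) directly in float16."""
    decay = np.exp(np.float16(log_w))
    a = np.float16(0.0)
    with np.errstate(over="ignore"):  # silence the overflow warning; the result becomes inf
        for x in log_kv:
            a = decay * a + np.exp(np.float16(x))
    return a

def log_space_accumulate(log_w, log_kv):
    """Same recurrence with the state kept as log(a_t):
    log a_t = logaddexp(log_w + log a_{t-1}, log_kv[t]).
    ASSUMES each k * v > 0, so log_kv[t] = log(k * v) is defined."""
    log_a = -np.inf  # log(0): empty state
    for x in log_kv:
        log_a = np.logaddexp(log_w + log_a, x)
    return log_a

log_kv = [12.0] * 10  # exp(12) ~ 1.6e5 already exceeds float16's max (65504)
print(naive_accumulate_fp16(-0.5, log_kv))  # inf
print(log_space_accumulate(-0.5, log_kv))   # finite, ~12.93
```

Converting back with `exp(log_att_aa)` is only safe once the value fits the target dtype. Here the final state (about e^12.93, roughly 4×10^5) still exceeds float16's range, so it would have to stay in log space (or be normalized) rather than be converted back to float16.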
2. Improved Time-Decay Initialization
RWKV-6:
# Simple logarithmic spacing
time_decay[i] = -5.0 + 8.0 * (i / n_layers)
RWKV-7:
# Adaptive per-head decay with better range
for layer in range(n_layers):
    for head in range(n_heads):
        # Different heads specialize in different timescales
        alpha = (layer / n_layers) ** 0.7  # Non-linear progression
        beta = (head / n_heads) * 0.5
        time_decay[layer, head] = -6.0 + 9.0 * alpha + beta
# Result: Better long/short-term memory balance
Impact: 15-20% perplexity improvement on long-context tasks
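A vectorized NumPy sketch of the same initialization, plus a helper that turns a decay value into a memory half-life. The half-life helper assumes the per-step decay factor is w = exp(-exp(time_decay)), the convention of RWKV-4 through RWKV-6. The RWKV-7 mapping may differ.

```python
import numpy as np

def rwkv7_time_decay_init(n_layers, n_heads):
    """Vectorized form of the per-(layer, head) initialization above."""
    layer = np.arange(n_layers)[:, None] / n_layers  # (n_layers, 1)
    head = np.arange(n_heads)[None, :] / n_heads     # (1, n_heads)
    alpha = layer ** 0.7   # non-linear progression over depth
    beta = head * 0.5      # per-head offset in [0, 0.5)
    return -6.0 + 9.0 * alpha + beta  # (n_layers, n_heads)

def memory_half_life(time_decay):
    """Tokens until a stored value decays to half, ASSUMING a per-step
    decay factor w = exp(-exp(time_decay)) as in RWKV-4 through RWKV-6."""
    # w ** n = 0.5  =>  n = ln(0.5) / ln(w) = ln(2) * exp(-time_decay)
    return np.log(2.0) * np.exp(-time_decay)

td = rwkv7_time_decay_init(n_layers=32, n_heads=32)
print(td.min(), td.max())
print(memory_half_life(td[0, 0]), memory_half_life(td[-1, -1]))
```

Under that assumption, layer 0 / head 0 starts with a half-life of roughly 280 tokens (ln 2 · e^6). The deepest layers start with half-lives well under one token, so they track very local context.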
3. Multi-Head Time-Mixing Refinements
RWKV-6 Multi-Head:
# Simple concatenation
heads = [head_i(x) for head_i in heads]
output = concat(heads)
RWKV-7 Multi-Head:
# Attention-style output projection with per-head layer norm
heads = [head_i(x) for head_i in heads]
# Per-head layer norm, applied before the heads are mixed
heads = [head_norm[i](head) for i, head in enumerate(heads)]  # Separate norm per head
concat_heads = concat(heads)
output = output_proj(concat_heads)  # Learnable mixing
Result: Better head specialization, 8-12% quality improvement
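A framework-free NumPy sketch of per-head normalization followed by a learnable output projection, applied in that order. Normalizing each head's slice of the channel axis is what `nn.GroupNorm` with one group per head computes. The learnable affine is omitted here for brevity.

```python
import numpy as np

def per_head_layer_norm(x, n_heads, eps=1e-5):
    """Normalize each head's slice of the channel axis independently.
    x: (B, T, d_model) concatenated head outputs. No learnable affine (omitted)."""
    B, T, C = x.shape
    h = x.reshape(B, T, n_heads, C // n_heads)
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return ((h - mean) / np.sqrt(var + eps)).reshape(B, T, C)

def mix_heads(head_outputs, w_out, n_heads):
    """Per-head norm first, then a learnable output projection across all heads."""
    return per_head_layer_norm(head_outputs, n_heads) @ w_out.T  # w_out: (d_model, d_model)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 64)) * np.repeat([1.0, 100.0], 32)  # head 1 is 100x louder
y = mix_heads(x, rng.normal(size=(64, 64)) / 8.0, n_heads=2)
print(y.shape)  # (2, 5, 64)
```

Normalizing before the projection puts every head on the same scale, so a single loud head cannot dominate the learned mixing.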
4. Rotary Position Encoding (RoPE) Integration
New in RWKV-7:
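import torch.nn as nn
from rotary_embedding_torch import RotaryEmbedding  # pip install rotary-embedding-torch

class RWKV7_TimeMix(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryEmbedding(self.d_head)

    def forward(self, x):
        B, T, _ = x.shape
        # Split into heads: (B, n_heads, T, d_head); RoPE rotates along the T axis (dim -2)
        k = self.key(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # Apply RoPE to keys
        k = self.rope.rotate_queries_or_keys(k)
        # WKV with position-aware keys (self.wkv: the WKV recurrence, not shown)
        wkv = self.wkv(k, v)
        return wkv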
Why useful: Improves positional awareness without breaking O(n) complexity, because the rotation is applied to each key independently and costs only O(d) per token.
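A minimal NumPy sketch of the standard RoPE rotation, written independently of `rotary-embedding-torch`. It uses the "rotate-half" channel pairing, while some implementations interleave pairs instead. It shows the two properties that matter here: each position is rotated on its own, and dot products between rotated vectors depend only on their relative offset.

```python
import numpy as np

def apply_rope(x, base=10000.0):
    """Rotate each (i, i + d/2) channel pair of x (shape (T, d), d even) by angle pos * base**(-2i/d).
    Uses the rotate-half pairing; other implementations interleave pairs instead."""
    T, d = x.shape
    half = d // 2
    inv_freq = base ** (-np.arange(half) / half)        # (half,)
    angles = np.arange(T)[:, None] * inv_freq[None, :]  # (T, half): one angle per position and pair
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# The same vector placed at every position: after RoPE, dot products depend only on the offset
v = np.random.default_rng(0).normal(size=8)
r = apply_rope(np.tile(v, (16, 1)))
print(np.dot(r[2], r[5]), np.dot(r[10], r[13]))  # equal up to rounding (offset 3 in both)
```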
5. RWKV-7 Block Structure
import torch.nn as nn

class RWKV7_Block(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # Multi-head time-mixing with RoPE
        self.att = RWKV7_MultiHeadTimeMix(d_model, n_heads)
        # Enhanced channel-mixing
        self.ffn = RWKV7_ChannelMix(d_model, hidden_ratio=3.5)  # Larger FFN

    def forward(self, x, state):
        # Pre-norm residual (like GPT); state carries the recurrent memory between calls
        att_out, new_state = self.att(self.ln1(x), state)
        x = x + att_out
        # FFN with gating
        ffn_out = self.ffn(self.ln2(x))
        x = x + ffn_out
        return x, new_state
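Because the only path between tokens is the recurrent state, running a block on a sequence in one call or in chunks (passing `state` forward) should give the same result. The toy NumPy sketch below checks that contract with the same pre-norm residual layout as `RWKV7_Block.forward`. Both mixing layers are hypothetical stand-ins, not RWKV-7's actual math, and they ignore token shift.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def toy_time_mix(x, state, w=0.9):
    """Hypothetical stand-in for the time-mixing layer: a per-channel decayed running sum.
    x: (T, d); state: (d,) carried across calls."""
    out = np.empty_like(x)
    for t in range(x.shape[0]):
        state = w * state + x[t]
        out[t] = state
    return out, state

def toy_block(x, state):
    """Same residual layout as RWKV7_Block.forward: pre-norm time-mix, then pre-norm FFN."""
    att_out, new_state = toy_time_mix(layer_norm(x), state)
    x = x + att_out
    x = x + np.tanh(layer_norm(x))  # hypothetical stand-in for the channel-mix FFN
    return x, new_state

x = np.random.default_rng(0).normal(size=(8, 4))
full, s_full = toy_block(x, np.zeros(4))
a, s_mid = toy_block(x[:5], np.zeros(4))
b, s_end = toy_block(x[5:], s_mid)  # resume from the carried state
print(np.allclose(np.vstack([a, b]), full), np.allclose(s_end, s_full))  # True True
```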
Multimodal Capabilities
Vision Encoder Integration
Architecture:
import torch
import torch.nn as nn

class RWKV7_Multimodal(nn.Module):
    def __init__(self, d_model=2560, vision_dim=1024):
        super().__init__()
        # Vision encoder (CLIP-style ViT)
        self.vision_encoder = VisionTransformer(
            patch_size=14,
            d_model=vision_dim,
            n_layers=24
        )
        # Projection from vision width to RWKV width
        self.vision_proj = nn.Linear(vision_dim, d_model)
        # RWKV language model
        self.rwkv = RWKV7_LanguageModel(d_model=d_model, n_layers=40)

    def forward(self, image, text_emb, state=None):
        # Encode image to patches: 256 = (224 / 14)^2 for a 224x224 input
        vision_tokens = self.vision_encoder(image)  # (B, 256, 1024)
        vision_tokens = self.vision_proj(vision_tokens)  # (B, 256, 2560)
        # Concatenate vision tokens and already-embedded text tokens along the sequence axis
        combined = torch.cat([vision_tokens, text_emb], dim=1)  # (B, 256 + T, 2560)
        # Process with RWKV
        out, state = self.rwkv(combined, state)
        return out, state
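A shape-only NumPy sketch of that forward pass, without the batch dimension. It assumes a 224×224 RGB input, which yields the 256 patches noted in the code comments. Random matrices stand in for the ViT and `vision_proj`, so only the shapes and the patch ordering are meaningful.

```python
import numpy as np

def patchify(image, patch=14):
    """Split an (H, W, C) image into non-overlapping patch x patch tiles, row-major.
    Returns (num_patches, patch * patch * C)."""
    H, W, C = image.shape
    gh, gw = H // patch, W // patch
    tiles = image[:gh * patch, :gw * patch].reshape(gh, patch, gw, patch, C)
    return tiles.transpose(0, 2, 1, 3, 4).reshape(gh * gw, patch * patch * C)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))                       # ASSUMED 224x224 RGB input
patches = patchify(image)                               # (256, 588): (224 / 14)^2 = 256
vision_feats = patches @ rng.normal(size=(588, 1024))   # random stand-in for the ViT -> (256, 1024)
vision_tokens = vision_feats @ rng.normal(size=(1024, 2560))  # stand-in for vision_proj -> (256, 2560)
text_emb = rng.normal(size=(7, 2560))                   # 7 already-embedded text tokens
combined = np.concatenate([vision_tokens, text_emb], axis=0)
print(combined.shape)  # (263, 2560)
```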
Vision-Language Tasks
Image Captioning:
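import torch

model = RWKV7_Multimodal()
model.eval()
# Encode image and project it into RWKV space
image = load_image('cat.jpg')  # (1, 3, 224, 224)
with torch.no_grad():
    vision_tokens = model.vision_proj(model.vision_encoder(image))  # (1, 256, 2560)
    # Prime the recurrent state with the image
    state = None
    _, state = model.rwkv(vision_tokens, state)
    # Autoregressive caption generation: one token per step, carrying the state
    token = torch.tensor([[bos_token_id]])  # start-of-caption token id (assumed)
    caption = []
    for _ in range(max_length):
        # model.rwkv.emb: token embedding layer (attribute name assumed)
        logits, state = model.rwkv(model.rwkv.emb(token), state)
        token = sample(logits[:, -1:])  # (1, 1) next-token id
        if token.item() == eos_token_id:
            break
        caption.append(token.item())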
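`sample` (like `load_image` and `max_length`) is not defined in the captioning loop. One common choice is nucleus (top-p) sampling. Below is a minimal NumPy version operating on a 1-D logits vector; it is an assumed choice, not taken from any RWKV package.

```python
import numpy as np

def sample(logits, temperature=1.0, top_p=0.9, rng=None):
    """Nucleus (top-p) sampling over a 1-D logits vector; returns a token id.
    ASSUMED sampler choice; temperature and top_p defaults are illustrative."""
    rng = np.random.default_rng() if rng is None else rng
    z = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(z - z.max())        # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]    # token ids, most likely first
    cum = np.cumsum(probs[order])
    # Keep the smallest prefix whose probability mass reaches top_p
    keep = order[:np.searchsorted(cum, top_p) + 1]
    p = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=p))

rng = np.random.default_rng(0)
print(sample([1.0, 5.0, 2.0], top_p=0.9, rng=rng))  # 1: token 1 alone holds ~94% of the mass
```

Lowering `top_p` toward 0 makes this greedy decoding; raising it toward 1 approaches plain sampling from the softmax.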
VQA (Visual Question Answering):
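# Question: "What color is the cat?"
question_ids = torch.tensor([tokenizer.encode("What color is the cat?")])  # (1, T_q)
# Process image + question from a fresh state (model.rwkv.emb: token embedding, name assumed)
combined = torch.cat([vision_tokens, model.rwkv.emb(question_ids)], dim=1)
answer_logits, state = model.rwkv(combined, None)
# Decode the answer from answer_logits[:, -1] and state as in captioning, e.g. "orange"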
Training Multimodal RWKV-7
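import torch
import torch.nn.functional as F

# Pretrain vision encoder (CLIP-style)
train_vision_encoder(image_text_pairs)
# Freeze vision encoder
model.vision_encoder.requires_grad_(False)
# Train projection + RWKV (optimizer built over vision_proj and rwkv parameters)
for images, captions in multimodal_dataloader:  # captions: (B, T) token ids
    # Forward
    vision_tokens = model.vision_encoder(images)
    vision_tokens = model.vision_proj(vision_tokens)  # (B, V, d_model)
    text_emb = model.rwkv.emb(captions[:, :-1])  # (B, T-1, d_model); emb name assumed
    logits, _ = model.rwkv(
        torch.cat([vision_tokens, text_emb], dim=1),
        state=None
    )
    # Loss (next-token prediction): position V-1+i predicts captions[:, i],
    # so the last vision position predicts the first caption token
    V = vision_tokens.shape[1]
    loss = F.cross_entropy(
        logits[:, V - 1:].reshape(-1, vocab_size),
        captions.reshape(-1)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()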
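The offset between logits and targets is easy to get wrong when a vision prefix is prepended. The NumPy sketch below makes the alignment explicit for a caption-only loss, assuming the input sequence is `[vision tokens, captions[:, :-1]]` as in the loop above.

```python
import numpy as np

def caption_loss(logits, captions, n_vision):
    """Mean next-token cross-entropy over caption positions only.
    ASSUMES logits come from the input [vision tokens, captions[:, :-1]]:
      logits:   (B, n_vision + T - 1, vocab)
      captions: (B, T) integer token ids
    Position n_vision - 1 + i predicts captions[:, i]."""
    pred = logits[:, n_vision - 1:]  # (B, T, vocab)
    if pred.shape[1] != captions.shape[1]:
        raise ValueError("logits and captions are misaligned")
    z = pred - pred.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    nll = -np.take_along_axis(log_probs, captions[..., None], axis=-1)[..., 0]
    return nll.mean()
```

With uniform logits the loss equals log(vocab), and putting all mass on `captions[:, i]` at position `n_vision - 1 + i` drives it to zero. This is a quick sanity check that the slice is off by exactly one.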
Scaling to 14B Parameters
Model Configuration
| Model | Layers | d_model | n_heads | Params | Context | VRAM (FP16 weights) |
|---|---|---|---|---|---|---|
| RWKV-7-1.5B | 24 | 2048 | 16 | 1.5B | Infinite | 3 GB |
| RWKV-7-3B | 32 | 2560 | 20 | 3B | Infinite | 6 GB |
| RWKV-7-7B | 32 | 4096 | 32 | 7B | Infinite | 14 GB |
| RWKV-7-14B | 40 | 5120 | 40 | 14B | Infinite | 28 GB |
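A rough cross-check of the Params and VRAM columns. It uses a transformer-style rule of thumb of about 12·d_model² parameters per layer, which is an assumption for RWKV blocks. It adds an input embedding and output head over an assumed 65,536-token vocabulary and counts 2 bytes per FP16 parameter.

```python
def estimate_params(n_layers, d_model, vocab_size=65536):
    """Rough parameter count: ~12 * d_model^2 per layer (rule of thumb, ASSUMED to carry
    over to RWKV blocks) plus input embedding and output head (vocab_size ASSUMED)."""
    return 12 * n_layers * d_model ** 2 + 2 * vocab_size * d_model

def fp16_weight_gb(n_params):
    return 2 * n_params / 1e9  # 2 bytes per parameter, weights only

for name, layers, d in [("1.5B", 24, 2048), ("3B", 32, 2560), ("7B", 32, 4096), ("14B", 40, 5120)]:
    p = estimate_params(layers, d)
    print(f"{name}: ~{p / 1e9:.2f}B params, ~{fp16_weight_gb(p):.1f} GB FP16 weights")
```

The estimates land within roughly 5-6% of the nominal sizes. The VRAM column matches 2 bytes per parameter, i.e. weights only; activations and the recurrent state come on top.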
Training Efficiency Improvements
RWKV-6 Training (7B):
- Speed: 45K tokens/sec (8× A100)
- Memory: 38 GB per GPU (4K sequence)
- Stability: Occasional loss spikes
RWKV-7 Training (14B):
- Speed: 52K tokens/sec (8× A100) - 15% faster
- Memory: 42 GB per GPU (4K sequence) - Better utilization
- Stability: No loss spikes - Improved stability
Key optimization: Fused CUDA kernels for multi-head WKV
RWKV-7 vs GPT-3 (14B)
| Metric | RWKV-7-14B | GPT-3-13B | Advantage |
|---|---|---|---|
| Training Speed | 52K tok/s | 28K tok/s | 1.9× |
| Inference (2K ctx) | 6,100 tok/s | 1,800 tok/s | 3.4× |
| Inference (8K ctx) | 5,800 tok/s | 450 tok/s | 12.9× |
| Memory (inference) | 28 GB | 52 GB | 1.9× |
| Perplexity (Pile) | 6.8 | 7.2 | 5.6% lower |
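The inference gap widens with context because a Transformer caches keys and values for every past token, while RWKV's recurrent state stays fixed in size. The back-of-the-envelope calculator below assumes standard multi-head attention (no grouped-query or multi-query sharing), an FP16 cache, batch size 1, and a 40-layer, 5120-wide model comparable to the 13-14B rows.

```python
def kv_cache_gb(n_layers, d_model, seq_len, batch=1, bytes_per_value=2):
    """Keys + values cached for every past token in every layer (standard multi-head attention)."""
    return 2 * n_layers * d_model * seq_len * batch * bytes_per_value / 1e9

# ASSUMED 40-layer, 5120-wide Transformer, FP16 cache, batch of 1
for ctx in (2_048, 8_192, 100_000):
    print(ctx, round(kv_cache_gb(40, 5120, ctx), 2))
```

At 8K tokens the cache already adds several GB per sequence on top of the weights. At 100K tokens it is nearly three times the FP16 weights of the 14B model.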
Production Use Cases
Microsoft Integration
Windows Copilot (Limited Release):
- Uses RWKV-7-3B for on-device inference
- 5-8× faster than GPT-2 with better quality
- Constant memory for infinite context
Office 365 (Experimental):
- Document summarization with RWKV-7-7B
- Handles 100K+ token documents efficiently
- No KV cache storage needed
NVIDIA NeMo Support
NeMo Guardrails with RWKV-7:
from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.llm.providers import register_llm_provider
# Register RWKV-7 as LLM backend (RWKV7Provider: a LangChain-compatible LLM wrapper, not shown)
register_llm_provider("rwkv7", RWKV7Provider)
# config/config.yml selects it under `models:` with `type: main` and `engine: rwkv7`
config = RailsConfig.from_path("config/")
rails = LLMRails(config)
# Use for content moderation
response = rails.generate(messages=[{"role": "user", "content": "..."}])
Benchmarks (RWKV-7 vs RWKV-6)
Language Modeling (perplexity, lower is better)
| Dataset | RWKV-6-7B | RWKV-7-7B | Improvement |
|---|---|---|---|
| Pile (val) | 7.8 | 7.1 | +9% |
| C4 | 9.3 | 8.6 | +8% |
| WikiText-103 | 8.4 | 7.7 | +8% |
| Lambada | 11.2 | 9.8 | +13% |
Long-Context Tasks (32K context)
| Task | RWKV-6-7B | RWKV-7-7B | Improvement |
|---|---|---|---|
| QuALITY | 52.3 | 61.8 | +18% |
| Qasper | 38.1 | 46.7 | +23% |
| NarrativeQA | 41.2 | 49.5 | +20% |
RWKV-7's improved time-decay helps long-context understanding substantially, with gains of 18-23% on all three tasks.
Multimodal Benchmarks
| Task | RWKV-7-7B | LLaVA-7B | BLIP-2-7B |
|---|---|---|---|
| VQAv2 | 74.2 | 78.5 | 82.1 |
| GQA | 58.3 | 62.1 | 65.4 |
| TextVQA | 51.2 | 58.2 | 60.8 |
| COCO Caption | 118.3 | 125.7 | 132.4 |
Note: RWKV-7 is competitive but not state of the art on vision-language tasks; the vision-focused LLaVA-7B and BLIP-2-7B score higher on every benchmark above.
Migration from RWKV-6 to RWKV-7
Model Conversion
import torch

# Load RWKV-6 checkpoint
rwkv6_state = torch.load('rwkv6-7b.pth', map_location='cpu')
# Initialize RWKV-7 model
rwkv7_model = RWKV7_Model(d_model=4096, n_layers=32, n_heads=32)
target = rwkv7_model.state_dict()
# Convert weights (mostly compatible)
for key, tensor in rwkv6_state.items():
    if 'time_mixing' in key:
        # RWKV-7 uses multi-head: rename the key and reshape to the per-head layout
        rwkv7_key = convert_key_to_multihead(key)
        target[rwkv7_key] = tensor.reshape(target[rwkv7_key].shape)
    elif key in target:
        # Direct copy
        target[key] = tensor
rwkv7_model.load_state_dict(target)
# Fine-tune on small dataset to adapt (parameters new in RWKV-7 keep their fresh init)
finetune(rwkv7_model, small_dataset, epochs=1)
State Compatibility
RWKV-6 State:
state_v6 = (att_aa, att_ab, att_x_prev, ffn_x_prev) # 4 components
RWKV-7 State (Multi-head):
state_v7 = (
    att_aa_heads,  # (n_heads, d_model//n_heads)
    att_ab_heads,  # (n_heads, d_model//n_heads)
    att_x_prev,
    ffn_x_prev
)  # 4 components, but att_* are multi-head
Conversion:
# Split RWKV-6 state into RWKV-7 multi-head state
def convert_state_v6_to_v7(state_v6, n_heads):
    att_aa, att_ab, att_x_prev, ffn_x_prev = state_v6
    d_head = att_aa.shape[-1] // n_heads
    # (..., d_model) -> (..., n_heads, d_head); channels h*d_head:(h+1)*d_head become head h
    att_aa_heads = att_aa.reshape(*att_aa.shape[:-1], n_heads, d_head)
    att_ab_heads = att_ab.reshape(*att_ab.shape[:-1], n_heads, d_head)
    return (att_aa_heads, att_ab_heads, att_x_prev, ffn_x_prev)
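Splitting is a pure reshape, so it can be undone exactly, which is handy for checking a migrated state against the original. A framework-free NumPy sketch of the split and its inverse follows; the helper names are illustrative, not part of any RWKV package.

```python
import numpy as np

def split_heads(x, n_heads):
    """(..., d_model) -> (..., n_heads, d_head); head h gets channels h*d_head:(h+1)*d_head."""
    *lead, d_model = x.shape
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    return x.reshape(*lead, n_heads, d_model // n_heads)

def merge_heads(x):
    """Inverse of split_heads: (..., n_heads, d_head) -> (..., n_heads * d_head)."""
    *lead, n_heads, d_head = x.shape
    return x.reshape(*lead, n_heads * d_head)

state = np.arange(2 * 8, dtype=np.float32).reshape(2, 8)  # batch of 2, d_model = 8
heads = split_heads(state, n_heads=4)                     # (2, 4, 2)
print(heads[0, 1])  # [2. 3.]: head 1 holds channels 2-3
```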
Resources
- Paper: https://arxiv.org/abs/2503.14456 (RWKV-7, March 2025)
- GitHub: https://github.com/BlinkDL/RWKV-LM (RWKV-v7 directory)
- Models: https://huggingface.co/BlinkDL/rwkv-7-world
- Multimodal Demo: https://huggingface.co/spaces/BlinkDL/RWKV-7-Multimodal
- Discord: https://discord.gg/bDSBUMeFpc
- Wiki: https://wiki.rwkv.com/rwkv7