A diffusion-based reasoning model using BERT for chain-of-thought generation with RL training (diffu-GRPO).
Input Question
↓
[BERT-based Diffusion Model] ← RL Training (diffu-GRPO)
↓ (iterative denoising)
Complete Response (Reasoning + Answer)
Simplified diffusion-only architecture:
- Base model: Small BERT (29M params, masked language model)
- Single model generates both reasoning and final answer
- No separate AR decoder or bridge layer needed
- Training: Masked diffusion on reasoning+answer sequences
- Inference: 20-step iterative denoising
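The 20-step progressive unmasking can be sketched as follows. This is a minimal toy: the dummy model and the confidence-based unmasking schedule are illustrative assumptions, not the repo's exact code.

```python
import numpy as np

MASK_ID = 103    # BERT's [MASK] token id
STEPS = 20       # denoising steps, matching the inference setting above

rng = np.random.default_rng(0)

def dummy_denoiser(tokens, vocab=30522):
    """Stand-in for the BERT denoiser: one logit vector per position."""
    return rng.normal(size=(len(tokens), vocab))

def iterative_denoise(model, seq_len=32, steps=STEPS):
    tokens = np.full(seq_len, MASK_ID)               # start fully masked
    for step in range(steps):
        masked = np.flatnonzero(tokens == MASK_ID)
        if masked.size == 0:
            break
        logits = model(tokens)
        logits[:, MASK_ID] = -np.inf                 # never predict [MASK]
        # unmask the most confident slice of the remaining masked positions
        k = max(1, int(np.ceil(masked.size / (steps - step))))
        conf = logits.max(axis=-1)
        pick = masked[np.argsort(-conf[masked])[:k]]
        tokens[pick] = logits[pick].argmax(axis=-1)
    return tokens

out = iterative_denoise(dummy_denoiser)
assert not np.any(out == MASK_ID)   # fully unmasked after 20 steps
```

With a real denoiser, high-confidence tokens (often the answer scaffold) are committed first and the remaining masks are filled in over later steps.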
```bash
# Run the setup script
./SETUP.sh

# Activate environment
source .venv/bin/activate

# Verify everything works
python scripts/smoke_test.py
```

Expected output:

```
============================================================
✓ ALL TESTS PASSED!
============================================================
```
Train on a tiny subset to verify everything works (5-10 minutes):
```bash
python scripts/run_phase1.py \
    --config configs/tiny.yaml \
    --output_dir ./checkpoints/test \
    --no_wandb
```

You should see:
- Model loading (~30 seconds)
- Dataset preprocessing (~1 minute)
- Training progress bar
- Loss decreasing over time
- Checkpoints being saved
This project uses uv for fast dependency management.
```bash
# Install uv (if not already installed)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Clone the repository
git clone <your-repo-url>
cd diffusion-reasoning

# Create virtual environment and install dependencies
uv sync

# Activate environment
source .venv/bin/activate
```

Note: Local development on an M-series Mac uses JAX with the Metal backend. Training will be slower than on a GPU but is fine for development and debugging.
For actual training runs, use the provided Colab notebook:
- Open `notebooks/colab_setup.ipynb` in Google Colab
- Follow the setup cells to install dependencies with CUDA support
- Mount Google Drive for checkpoint persistence
- Run training
The Colab setup will use JAX with CUDA instead of Metal.
- Diffusion Model: Based on BERT (masked LM), generates reasoning and answer together
- Iterative Denoising: 20-step denoising process
- Data Pipeline: GSM8K with reasoning+answer format
- Training Loop: Full SFT with checkpointing, evaluation, W&B integration
- Inference: Progressive unmasking to reveal reasoning and answer
- RL Training: GRPO trainer for optimizing reasoning quality
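The SFT objective behind "masked diffusion on reasoning+answer sequences" can be sketched as random-ratio masking over the target span. The helper name and the `-100` ignore-index are illustrative assumptions, not the repo's exact code.

```python
import numpy as np

MASK_ID = 103   # BERT's [MASK] token id
rng = np.random.default_rng(0)

def mask_for_sft(target_tokens, ratio_min=0.5, ratio_max=1.0):
    """Sample a masking ratio in [ratio_min, ratio_max] and mask that
    fraction of the reasoning+answer tokens; loss is computed only on
    the masked positions (label -100 = ignore)."""
    tokens = np.asarray(target_tokens)
    ratio = rng.uniform(ratio_min, ratio_max)
    n_mask = max(1, int(round(ratio * tokens.size)))
    idx = rng.choice(tokens.size, size=n_mask, replace=False)
    noisy = tokens.copy()
    noisy[idx] = MASK_ID
    labels = np.where(noisy == MASK_ID, tokens, -100)
    return noisy, labels

noisy, labels = mask_for_sft(np.arange(1000, 1032))  # 32 dummy token ids
assert 0.5 <= (noisy == MASK_ID).mean() <= 1.0
```

Sampling the ratio per example (rather than fixing it) exposes the model to every noise level it will see during the 20-step denoising at inference time.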
diffusion-reasoning/
├── configs/
│ └── tiny.yaml ← Model and training config (BERT-based)
├── src/
│ ├── models/
│ │ └── diffusion_model.py ← Main model (BERT + diffusion)
│ ├── data/
│ │ ├── dataset.py ← Data loading
│ │ └── masking.py ← Masking logic
│ ├── training/
│ │ ├── train_sft.py ← Training loop
│ │ └── utils.py ← Checkpointing
│ └── rl/
│ ├── environment.py ← Gym environment
│ └── grpo_trainer.py ← GRPO trainer
├── scripts/
│ ├── run_phase1.py ← Main entry point
│ └── smoke_test.py ← Verification tests
└── notebooks/
└── colab_setup.ipynb ← Colab training notebook
```yaml
model:
  base_model: "google/bert_uncased_L-4_H-512_A-8"  # Small BERT
  vocab_size: 30522      # BERT vocab size
  d_model: 512           # BERT hidden size
  n_heads: 8
  n_layers: 4
  max_seq_len: 512       # Increased for reasoning space

training:
  batch_size: 8          # Batch size (adjust for GPU memory)
  learning_rate: 2e-5    # Lower for BERT fine-tuning
  max_steps: 10000       # Total training steps
  mask_ratio_min: 0.5    # Minimum masking ratio
  mask_ratio_max: 1.0    # Maximum masking ratio
```

Quick training test:

```bash
python scripts/run_phase1.py \
    --config configs/tiny.yaml \
    --output_dir ./checkpoints/test \
    --no_wandb
```

Full training with W&B:

```bash
python scripts/run_phase1.py \
    --config configs/tiny.yaml \
    --output_dir ./checkpoints/phase1 \
    --wandb_project diffusion-reasoning
```

Resume training:

```bash
python scripts/run_phase1.py \
    --config configs/tiny.yaml \
    --output_dir ./checkpoints/phase1 \
    --resume_from ./checkpoints/phase1
```

Train with diffu-GRPO to optimize reasoning quality.
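The core of GRPO-style training is the group-relative advantage: several completions are sampled per question and each is scored against its own group's statistics, so no value network is needed. A simplified sketch of that idea (not the trainer's exact code):

```python
import numpy as np

def grpo_advantages(rewards, eps=1e-8):
    """GRPO-style group-relative advantage: normalize each sampled
    completion's reward by the mean/std of its own group."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# e.g. 4 sampled reasoning traces for one GSM8K question,
# reward 1.0 if the final answer is correct, else 0.0
adv = grpo_advantages([1.0, 0.0, 0.0, 1.0])  # ≈ [1, -1, -1, 1]
```

Correct traces get positive advantage and incorrect ones negative, which is what pushes the model toward higher-quality reasoning.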
Checkpoints are saved to `output_dir` every 500 steps by default:

- `checkpoint_N`: Model state at step N
- `metadata_N.json`: Training metrics and config
All checkpoints can be resumed from automatically.
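"Resumed automatically" can be implemented by scanning `output_dir` for the highest-numbered checkpoint. A sketch that follows the `checkpoint_N` naming above; the helper itself is illustrative, not the repo's code:

```python
import os
import re

def latest_checkpoint(output_dir):
    """Return the path of the highest-step checkpoint_N directory in
    output_dir, or None if no checkpoint exists yet."""
    steps = [
        int(m.group(1))
        for name in os.listdir(output_dir)
        if (m := re.fullmatch(r"checkpoint_(\d+)", name))
    ]
    return os.path.join(output_dir, f"checkpoint_{max(steps)}") if steps else None
```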
- Sign up at https://wandb.ai
- Run `wandb login`
- Train with `--wandb_project your-project-name`
- View metrics at https://wandb.ai/your-username/your-project-name
Add the `--no_wandb` flag; metrics will be printed to the console instead.
| Hardware | 10K Steps | Notes |
|---|---|---|
| M series Mac | 8-10 hours | Slow but good for debugging |
| Colab T4 | 4-6 hours | Free tier |
| Colab A100 | 2-3 hours | Pro+ recommended |
Early in training:
- Loss: ~8-10 → ~5-7
- Model generating random tokens
- Checkpoints being saved

Mid training:
- Loss: ~4-5
- Model starting to form words
- Some reasoning structure emerging

By the end of training:
- Loss: ~2-3
- Model generates coherent reasoning
- May not always be correct, but logical
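As a sanity check on the starting loss: a model guessing uniformly over BERT's 30,522-token vocabulary has cross-entropy ln(30522) ≈ 10.3 nats, so an initial loss in the ~8-10 range is what an untrained predictor should score:

```python
import math

# Cross-entropy of a uniform guess over the 30,522-token BERT vocab.
uniform_ce = math.log(30522)   # ≈ 10.33 nats, near the ~8-10 starting loss
```

If the loss starts far above this value, something is likely wrong with the labels or masking rather than the model.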
- Check data preprocessing: Are prompts properly formatted?
- Verify masking: Is reasoning+answer section being masked?
- Check learning rate: May need adjustment for BERT
Reduce the batch size in the config:

```yaml
training:
  batch_size: 4             # Reduce from 8
  gradient_accumulation: 8  # Increase to maintain effective batch size
```

Slow training on a Mac is normal: M-series Macs use the Metal backend, which is slower than CUDA. Use local runs for development only and switch to Colab for actual training.
If offline or behind firewall, download manually:
```bash
python -c "from datasets import load_dataset; load_dataset('gsm8k', 'main')"
```

Make sure you're in the virtual environment:

```bash
source .venv/bin/activate

# Verify Python can find src/
python -c "from src.models import create_diffusion_model; print('OK')"
```

Verify everything is working:

```bash
python scripts/smoke_test.py
```

This tests:
- BERT-based model creation
- Forward pass
- Masking strategy
- BERT tokenization
- Data loading
```bash
black src/ scripts/
ruff check src/ scripts/
```

```bash
# Activate environment
source .venv/bin/activate

# Run smoke test
python scripts/smoke_test.py

# Quick training test
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/test --no_wandb

# Full training with W&B
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/phase1 --wandb_project my-project

# Resume training
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/phase1 --resume_from ./checkpoints/phase1
```

- Verify training: Train for 10K steps and check that the loss decreases
- Generate examples: Use iterative denoising to create reasoning traces
- RL optimization: Use GRPO training to improve reasoning quality
- Scale up: Try larger BERT models (bert-base)
- IMPLEMENTATION_STATUS.md: Detailed implementation status
- PLAN.md: Original design document and architecture details
- notebooks/colab_setup.ipynb: Colab training notebook
- SETUP.sh: Automated setup script
- Start small: Use tiny config and small dataset first
- Monitor closely: Watch first 100 steps to catch issues early
- Save often: The default checkpoint interval of 500 steps works well; keep it
- Use W&B: Makes debugging much easier
- Test locally: Verify on Mac before expensive Colab runs
For issues or questions:
- Run `python scripts/smoke_test.py` to diagnose issues
- Check error messages carefully; they're usually informative
- Review the relevant source file for the component that's failing
- Check IMPLEMENTATION_STATUS.md for what's implemented
- Open an issue with:
- What you tried
- What happened
- Error messages
- System info (OS, Python version, GPU)
- Python 3.14+
- JAX with Metal (local) or CUDA (Colab)
- 16GB+ RAM recommended
- GPU highly recommended for training
MIT
Ready to start? Run `./SETUP.sh && source .venv/bin/activate && python scripts/smoke_test.py` 🚀