kevingil/diffusion-reasoning

Diffusion Reasoning Model

A diffusion-based reasoning model that uses a BERT masked language model to generate chain-of-thought reasoning and answers, fine-tuned with RL via diffu-GRPO.

Architecture

Input Question
     ↓
[BERT-based Diffusion Model] ← RL Training (diffu-GRPO)
     ↓ (iterative denoising)
Complete Response (Reasoning + Answer)

Simplified diffusion-only architecture:

  • Base model: Small BERT (29M params, masked language model)
  • Single model generates both reasoning and final answer
  • No separate AR decoder or bridge layer needed
  • Training: Masked diffusion on reasoning+answer sequences
  • Inference: 20-step iterative denoising
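The iterative denoising described above can be sketched in a few lines. This is an illustrative NumPy toy, not the repo's `src/models/diffusion_model.py`: the `model_logits` stub stands in for the BERT forward pass, and the confidence-based unmasking schedule is one plausible choice.

```python
import numpy as np

MASK_ID = 103  # BERT's [MASK] token id
NUM_STEPS = 20

def model_logits(tokens: np.ndarray) -> np.ndarray:
    """Stand-in for the BERT forward pass: returns deterministic random logits."""
    rng = np.random.default_rng(0)
    return rng.normal(size=(len(tokens), 30522))

def denoise(prompt: np.ndarray, answer_len: int) -> np.ndarray:
    """Progressively unmask the response region over NUM_STEPS."""
    tokens = np.concatenate([prompt, np.full(answer_len, MASK_ID)])
    for step in range(NUM_STEPS):
        masked = np.where(tokens == MASK_ID)[0]
        if masked.size == 0:
            break
        logits = model_logits(tokens)
        logits[:, MASK_ID] = -np.inf  # never predict [MASK] itself
        probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
        conf, preds = probs.max(-1), probs.argmax(-1)
        # Reveal the most confident masked positions this step
        k = max(1, masked.size // (NUM_STEPS - step))
        reveal = masked[np.argsort(-conf[masked])[:k]]
        tokens[reveal] = preds[reveal]
    return tokens
```

Each step commits the most confident masked positions, so early steps fix the "easy" tokens and later steps fill in the rest.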

Quick Start (5 Minutes)

1. Setup Environment

# Run the setup script
./SETUP.sh

# Activate environment
source .venv/bin/activate

# Verify everything works
python scripts/smoke_test.py

Expected output:

============================================================
✓ ALL TESTS PASSED!
============================================================

2. Your First Training Run

Train on a tiny subset to verify everything works (5-10 minutes):

python scripts/run_phase1.py \
  --config configs/tiny.yaml \
  --output_dir ./checkpoints/test \
  --no_wandb

You should see:

  1. Model loading (~30 seconds)
  2. Dataset preprocessing (~1 minute)
  3. Training progress bar
  4. Loss decreasing over time
  5. Checkpoints being saved

Installation

Local Development (M series Mac)

This project uses uv for fast dependency management.

# Install uv (if not already installed)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Clone the repository
git clone <your-repo-url>
cd diffusion-reasoning

# Create virtual environment and install dependencies
uv sync

# Activate environment
source .venv/bin/activate

Note: Local development on M series Mac uses JAX with Metal backend. Training will be slower than GPU but suitable for development and debugging.

Google Colab (for actual training)

For actual training runs, use the provided Colab notebook:

  1. Open notebooks/colab_setup.ipynb in Google Colab
  2. Follow the setup cells to install dependencies with CUDA support
  3. Mount Google Drive for checkpoint persistence
  4. Run training

The Colab setup will use JAX with CUDA instead of Metal.

Features

  • Diffusion Model: Based on BERT (masked LM), generates reasoning and answer together
  • Iterative Denoising: 20-step denoising process
  • Data Pipeline: GSM8K with reasoning+answer format
  • Training Loop: Full SFT with checkpointing, evaluation, W&B integration
  • Inference: Progressive unmasking to reveal reasoning and answer
  • RL Training: GRPO trainer for optimizing reasoning quality
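For context on the data format: GSM8K solutions end with a line of the form `#### <answer>`. Below is a plausible sketch of the reasoning+answer formatting; the exact template in `src/data/dataset.py` may differ.

```python
def format_example(question: str, solution: str) -> dict:
    """Split a GSM8K solution into reasoning and final answer.

    GSM8K solutions end with a line of the form '#### <answer>'.
    """
    reasoning, _, answer = solution.rpartition("#### ")
    return {
        "prompt": f"Question: {question}\nAnswer: ",
        # The model is trained to denoise this whole span:
        "target": reasoning.strip() + " #### " + answer.strip(),
    }

example = format_example(
    "Tom has 3 apples and buys 2 more. How many does he have?",
    "Tom starts with 3 apples. 3 + 2 = 5.\n#### 5",
)
```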

Project Structure

diffusion-reasoning/
├── configs/
│   └── tiny.yaml              ← Model and training config (BERT-based)
├── src/
│   ├── models/
│   │   └── diffusion_model.py ← Main model (BERT + diffusion)
│   ├── data/
│   │   ├── dataset.py         ← Data loading
│   │   └── masking.py         ← Masking logic
│   ├── training/
│   │   ├── train_sft.py       ← Training loop
│   │   └── utils.py           ← Checkpointing
│   └── rl/
│       ├── environment.py     ← Gym environment
│       └── grpo_trainer.py    ← GRPO trainer
├── scripts/
│   ├── run_phase1.py          ← Main entry point
│   └── smoke_test.py          ← Verification tests
└── notebooks/
    └── colab_setup.ipynb      ← Colab training notebook

Configuration

Editing configs/tiny.yaml

model:
  base_model: "google/bert_uncased_L-4_H-512_A-8"  # Small BERT
  vocab_size: 30522    # BERT vocab size
  d_model: 512         # BERT hidden size
  n_heads: 8
  n_layers: 4
  max_seq_len: 512     # Increased for reasoning space

training:
  batch_size: 8        # Batch size (adjust for GPU memory)
  learning_rate: 2e-5  # Lower for BERT fine-tuning
  max_steps: 10000     # Total training steps
  mask_ratio_min: 0.5  # Minimum masking ratio
  mask_ratio_max: 1.0  # Maximum masking ratio
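The `mask_ratio_min`/`mask_ratio_max` pair implies that each training example is corrupted with a randomly sampled masking ratio. A sketch of that corruption step, assumed from the config; the real `src/data/masking.py` may differ:

```python
import numpy as np

def mask_targets(target_ids, mask_ratio_min=0.5, mask_ratio_max=1.0,
                 mask_id=103, rng=None):
    """Mask a random fraction of the target (reasoning+answer) tokens."""
    if rng is None:
        rng = np.random.default_rng()
    ratio = rng.uniform(mask_ratio_min, mask_ratio_max)
    n_mask = max(1, int(round(ratio * len(target_ids))))
    positions = rng.choice(len(target_ids), size=n_mask, replace=False)
    corrupted = np.array(target_ids)
    corrupted[positions] = mask_id  # replace with [MASK]
    return corrupted, positions
```

Sampling the ratio up to 1.0 means the model sometimes has to reconstruct the entire response from the prompt alone, which matches the fully-masked starting point used at inference.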

Training

Masked Supervised Fine-Tuning

Quick Test (10 minutes on local)

python scripts/run_phase1.py \
  --config configs/tiny.yaml \
  --output_dir ./checkpoints/test \
  --no_wandb

Full Training (Colab recommended)

python scripts/run_phase1.py \
  --config configs/tiny.yaml \
  --output_dir ./checkpoints/phase1 \
  --wandb_project diffusion-reasoning

Resume from Checkpoint

python scripts/run_phase1.py \
  --config configs/tiny.yaml \
  --output_dir ./checkpoints/phase1 \
  --resume_from ./checkpoints/phase1

RL Training

Train with diffu-GRPO to optimize reasoning quality.
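As background, GRPO samples a group of completions per prompt, scores each (e.g. 1.0 for a correct final answer, 0.0 otherwise), and uses group-normalized rewards as advantages. A minimal sketch of that advantage computation; the actual `grpo_trainer.py` interface is not shown in this README:

```python
import numpy as np

def group_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize rewards within each group of sampled completions.

    rewards: shape (num_prompts, group_size).
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)

# Example: 2 prompts, 4 sampled completions each
rewards = np.array([[1.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 1.0]])
adv = group_advantages(rewards)
```

Completions that beat their group's mean reward get positive advantage, so no separate value network is needed.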

Checkpointing

Checkpoints are saved every 500 steps by default to output_dir:

  • checkpoint_N: Model state at step N
  • metadata_N.json: Training metrics and config

Training can automatically resume from the latest checkpoint in output_dir.

Monitoring

With Weights & Biases

  1. Sign up at https://wandb.ai
  2. Run wandb login
  3. Train with --wandb_project your-project-name
  4. View metrics at https://wandb.ai/your-username/your-project-name

Without W&B

Add --no_wandb flag. Metrics will be printed to console.

Expected Performance

Training Time

| Hardware | 10K Steps | Notes |
| --- | --- | --- |
| M series Mac | 8-10 hours | Slow but good for debugging |
| Colab T4 | 4-6 hours | Free tier |
| Colab A100 | 2-3 hours | Pro+ recommended |

Training Progress

After 100 Steps

  • Loss: ~8-10 → ~5-7
  • Model generating random tokens
  • Checkpoint saved

After 1,000 Steps

  • Loss: ~4-5
  • Model starting to form words
  • Some reasoning structure emerging

After 10,000 Steps

  • Loss: ~2-3
  • Model generates coherent reasoning
  • May not always be correct, but logical

Troubleshooting

Model not learning (loss not decreasing)

  • Check data preprocessing: Are prompts properly formatted?
  • Verify masking: Is reasoning+answer section being masked?
  • Check learning rate: May need adjustment for BERT

Out of Memory (OOM)

Reduce batch size in config:

training:
  batch_size: 4  # Reduce from 8
  gradient_accumulation: 8  # Increase to maintain effective batch size
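The effective batch size is `batch_size × gradient_accumulation` (here 4 × 8 = 32), since gradients from several microbatches are averaged before each optimizer update. A toy sketch of that pattern, illustrative rather than the repo's actual training loop:

```python
import numpy as np

def accumulated_step(grad_fn, microbatches, learning_rate, params):
    """Average gradients over microbatches, then apply one update."""
    total = np.zeros_like(params)
    for batch in microbatches:
        total += grad_fn(params, batch)
    grad = total / len(microbatches)
    return params - learning_rate * grad

# Toy example: gradient of (params - mean(batch))^2
grad_fn = lambda p, b: 2 * (p - np.mean(b))
params = np.array(10.0)
batches = [np.array([1.0, 3.0]), np.array([5.0, 7.0])]
params = accumulated_step(grad_fn, batches, 0.1, params)
```

Memory scales with the microbatch size, not the effective batch size, which is why this trades speed for headroom.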

Slow training on Mac

This is normal: the Metal backend on M series Macs is slower than CUDA. Use it for development only and switch to Colab for actual training runs.

Dataset download fails

If you are behind a firewall, or want to pre-cache the dataset before training, download it manually:

python -c "from datasets import load_dataset; load_dataset('gsm8k', 'main')"

Import errors

Make sure you're in the virtual environment:

source .venv/bin/activate

# Verify Python can find src/
python -c "from src.models import create_diffusion_model; print('OK')"

Development

Running Tests

Verify everything is working:

python scripts/smoke_test.py

This tests:

  • BERT-based model creation
  • Forward pass
  • Masking strategy
  • BERT tokenization
  • Data loading

Code Formatting

black src/ scripts/
ruff check src/ scripts/

Common Commands

# Activate environment
source .venv/bin/activate

# Run smoke test
python scripts/smoke_test.py

# Quick training test
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/test --no_wandb

# Full training with W&B
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/phase1 --wandb_project my-project

# Resume training
python scripts/run_phase1.py --config configs/tiny.yaml --output_dir ./checkpoints/phase1 --resume_from ./checkpoints/phase1

Next Steps

  1. Verify training: Train for 10K steps, check loss decreases
  2. Generate examples: Use iterative denoising to create reasoning traces
  3. RL optimization: Use GRPO training to improve reasoning quality
  4. Scale up: Try larger BERT models (bert-base)

Resources

  • IMPLEMENTATION_STATUS.md: Detailed implementation status
  • PLAN.md: Original design document and architecture details
  • notebooks/colab_setup.ipynb: Colab training notebook
  • SETUP.sh: Automated setup script

Pro Tips

  1. Start small: Use tiny config and small dataset first
  2. Monitor closely: Watch first 100 steps to catch issues early
  3. Save often: the default checkpoint interval of 500 steps is a good balance; there is rarely a reason to change it
  4. Use W&B: Makes debugging much easier
  5. Test locally: Verify on Mac before expensive Colab runs

Support

For issues or questions:

  1. Run python scripts/smoke_test.py to diagnose issues
  2. Check error messages carefully - they're usually informative
  3. Review the relevant source file for the component that's failing
  4. Check IMPLEMENTATION_STATUS.md for what's implemented
  5. Open an issue with:
    • What you tried
    • What happened
    • Error messages
    • System info (OS, Python version, GPU)

Requirements

  • Python 3.14+
  • JAX with Metal (local) or CUDA (Colab)
  • 16GB+ RAM recommended
  • GPU highly recommended for training

License

MIT


Ready to start? Run ./SETUP.sh && source .venv/bin/activate && python scripts/smoke_test.py 🚀
