# DocWaveDiff: A Predict-and-Refine approach for Document Image Enhancement with Wavelet U-Nets and Diffusion models

Authors: Matteo Marulli, Marco Bertini
DocWaveDiff is a novel document restoration method based on a predict-and-refine diffusion framework incorporating wavelet U-Nets. OCR and document layout analysis algorithms are essential components of AI-based document systems, yet they are typically trained on clean, degradation-free images. When applied to degraded documents, such as blurred scans or pages marred by handwritten annotations, their performance drops significantly.
Our approach addresses this issue through a two-stage restoration process:
- Early Predictor: Generates an initial restoration from a degraded image patch and optionally its prior features
- Denoiser Refiner: Estimates the residual image to refine the initial prediction
The combination of these outputs yields the final restored result.
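The two-stage combination above can be sketched as follows. This is a minimal, self-contained illustration of the predict-and-refine idea, not the repository's actual API; the function and the toy stand-in models are hypothetical.

```python
import numpy as np

def predict_and_refine(early_predictor, refiner, degraded):
    """Two-stage restoration sketch: coarse prediction plus estimated residual."""
    # Stage 1: the Early Predictor produces an initial restoration
    # directly from the degraded patch
    initial = early_predictor(degraded)
    # Stage 2: the Denoiser Refiner estimates a residual image that
    # corrects the initial prediction
    residual = refiner(degraded, initial)
    # The final restored result combines the two outputs
    return initial + residual

# Toy stand-ins (hypothetical): an identity predictor and a refiner
# that simply damps the initial prediction by 10%
early = lambda x: x
refine = lambda x, init: -0.1 * init
patch = np.random.rand(64, 64)
restored = predict_and_refine(early, refine, patch)
```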
## Key Features

- 🌊 Wavelet Transform Integration: Incorporates the Discrete Wavelet Transform (DWT) and Inverse Wavelet Transform (IWT) into the U-Net architecture
- 🔄 Predict-and-Refine Framework: Two-stage approach combining early prediction with diffusion-based refinement
- 📄 Multiple Degradation Types: Handles deblurring, handwriting removal, and various document degradation scenarios
## Installation

Clone the repository:

```bash
git clone https://github.com/yourusername/DocWaveDiff.git
cd DocWaveDiff
```

Create a new conda environment named DocWaveDiff with Python 3.12:

```bash
conda create -n DocWaveDiff python=3.12
conda activate DocWaveDiff
```

Install the required packages:
```bash
pip install -r requirements.txt
```

## Project Structure

```
DocWaveDiff/
├── src/                              # Source code
│   ├── unet.py                       # Wavelet U-Net architecture
│   ├── wavelets.py                   # DWT/IWT implementations (Haar wavelet)
│   ├── attentions.py                 # Attention block mechanisms
│   ├── blocks.py                     # Building blocks (DownBlock, UpBlock, MiddleBlock)
│   ├── swish.py                      # Swish activation function
│   ├── docwavediff_deblurring.py     # DocWaveDiff model for deblurring task
│   ├── docwavediff_inpanting.py      # DocWaveDiff model for inpainting/removal task
│   ├── trainer_deblurring.py         # Training logic for deblurring
│   ├── trainer_inpainting.py         # Training logic for inpainting
│   ├── train.py                      # Main training script
│   ├── ema.py                        # Exponential Moving Average for model weights
│   ├── sobel.py                      # Sobel edge detection utilities
│   ├── config.py                     # Configuration parsing
│   └── utils.py                      # Utility functions
├── confs/                            # Configuration files
│   ├── conf_document_deblurring.yml  # Config for document deblurring task
│   └── conf_document_inpainting.yml  # Config for handwriting removal task
├── data/                             # Data loading utilities
│   ├── docdata.py                    # Dataset loader for deblurring
│   └── docdata_inpainting.py         # Dataset loader for inpainting with masks
├── datasets/                         # Place your datasets here
│   ├── deblurring/
│   │   ├── train/
│   │   │   ├── degraded/
│   │   │   └── ground_truth/
│   │   └── test/
│   └── inpainting/
│       ├── train/
│       │   ├── degraded/
│       │   ├── ground_truth/
│       │   └── masks/
│       └── test/
├── checksave/                        # Model checkpoints saved here
│   ├── deblurring/
│   └── inpainting/
├── results/                          # Inference results saved here
│   ├── deblurring/
│   └── inpainting/
├── Training/                         # Training visualizations (on-the-fly restoration examples)
│   ├── deblurring/
│   └── inpainting/
├── schedule/                         # Diffusion scheduling utilities
├── utils/                            # Additional utilities
├── repo_images/                      # Images for README
│   ├── gif.gif                       # Demo visualization
│   ├── logoMICC_white.png            # MICC Lab logo
│   ├── DINFO_bianco.png              # DINFO Department logo
│   └── wacv_logo.png                 # WACV Conference logo
├── main.py                           # Main entry point
├── requirements.txt                  # Python dependencies
└── README.md                         # This file
```
## Training

### Document Deblurring

To train the model for document deblurring:

```bash
python main.py --config confs/conf_document_deblurring.yml
```

Edit `confs/conf_document_deblurring.yml` to customize:
- Dataset paths
- Model hyperparameters
- Training settings (epochs, batch size, learning rate)
- Checkpoint save frequency
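For orientation, a configuration file might contain entries along these lines. All key names below except `MODE` are hypothetical; consult the shipped `.yml` files for the actual schema.

```yaml
MODE: 1                    # set to 0 for test/inference
train_degraded: datasets/deblurring/train/degraded    # hypothetical key
train_gt: datasets/deblurring/train/ground_truth      # hypothetical key
epochs: 500                # hypothetical value
batch_size: 8              # hypothetical value
lr: 1.0e-4                 # hypothetical value
save_every: 50             # checkpoint save frequency (hypothetical key)
```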
### Handwriting Removal

To train the model for handwriting removal:

```bash
python main.py --config confs/conf_document_inpainting.yml
```

Edit `confs/conf_document_inpainting.yml` for task-specific configurations.
During training:

- Model weights are saved in `checksave/`
- Training visualizations (restoration examples on the training set) are saved in `Training/`
- Training progress and losses are logged to the console
## Inference

To restore degraded documents using a trained model, set the MODE flag to 0 (test) in the corresponding `.yml` configuration file.
Restored images will be saved in the `results/` directory.
## Dataset Preparation

Organize your dataset in the following structure:
For Deblurring:

```
datasets/deblurring/
├── train/
│   ├── degraded/      # Blurred document images
│   └── ground_truth/  # Clean ground truth images
└── test/
    ├── degraded/
    └── ground_truth/
```
For Handwriting Removal:

```
datasets/inpainting/
├── train/
│   ├── degraded/      # Images with handwriting
│   ├── ground_truth/  # Clean images without handwriting
│   └── masks/         # Binary masks indicating handwriting regions
└── test/
    ├── degraded/
    ├── ground_truth/
    └── masks/
```
Ensure that corresponding images have the same filename across folders.
## Architecture

### Wavelet U-Net

DocWaveDiff employs a Wavelet U-Net architecture that integrates wavelet transforms at multiple scales:
- Encoder: Progressively downsamples features using the DWT, capturing multi-scale information in the frequency domain
- Middle Block: Processes the bottleneck features with attention mechanisms
- Decoder: Reconstructs the image using the IWT and skip connections from the encoder
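In such encoders, DWT-based downsampling can replace strided convolutions: each Haar decomposition halves the spatial resolution and stacks the four subbands along the channel axis, so no information is lost. A NumPy sketch of the idea, single level only; the repository's exact feature layout may differ:

```python
import numpy as np

def dwt_downsample(feat):
    """Haar DWT as downsampling: (C, H, W) -> (4C, H/2, W/2), lossless."""
    a, b = feat[:, 0::2, 0::2], feat[:, 0::2, 1::2]  # top row of each 2x2 block
    c, d = feat[:, 1::2, 0::2], feat[:, 1::2, 1::2]  # bottom row
    ll = (a + b + c + d) / 2   # approximation subband
    lh = (a - b + c - d) / 2   # detail subbands
    hl = (a + b - c - d) / 2
    hh = (a - b - c + d) / 2
    # Stack subbands as extra channels so a plain conv can consume them
    return np.concatenate([ll, lh, hl, hh], axis=0)
```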
### Predict-and-Refine Framework

The predict-and-refine framework consists of:

- Early Predictor U-Net (`is_noise=False`): Generates the initial restoration without time conditioning
- Diffusion Refiner U-Net (`is_noise=True`): Refines the prediction through iterative denoising with time embeddings
### Haar Wavelet Decomposition

The model uses Haar wavelet decomposition:

- DWT (Discrete Wavelet Transform): Decomposes the image into four subbands:
  - LL (Low-Low): Approximation coefficients
  - LH (Low-High): Horizontal detail
  - HL (High-Low): Vertical detail
  - HH (High-High): Diagonal detail
- IWT (Inverse Wavelet Transform): Reconstructs the image from the wavelet subbands
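The subband arithmetic is simple enough to write out directly. Below is a single-level Haar DWT/IWT pair in NumPy, using one sign convention among several; the repository's `wavelets.py` may order or sign the subbands differently, but reconstruction is exact either way:

```python
import numpy as np

def haar_dwt2(x):
    """Single-level 2D Haar DWT on an (H, W) array with even H, W."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[1::2, 0::2], x[1::2, 1::2]  # bottom-left, bottom-right
    ll = (a + b + c + d) / 2             # LL: approximation coefficients
    lh = (a - b + c - d) / 2             # LH: detail
    hl = (a + b - c - d) / 2             # HL: detail
    hh = (a - b - c + d) / 2             # HH: diagonal detail
    return ll, lh, hl, hh

def haar_iwt2(ll, lh, hl, hh):
    """Exact inverse of haar_dwt2: reassemble the 2x2 blocks."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=ll.dtype)
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x
```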
### Time Embeddings

For the diffusion refiner, sinusoidal positional encodings are used to create time embeddings that condition the denoising process at each diffusion timestep.
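This is the standard Transformer-style encoding; a minimal sketch follows, where the embedding dimension and max period are illustrative defaults, not the repository's values:

```python
import numpy as np

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps: shape (...,) -> (..., dim)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[..., None] * freqs
    # First half sine, second half cosine
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
```

For `t = 0` the sine half is all zeros and the cosine half all ones; distinct timesteps map to distinct, smoothly varying vectors that the refiner's conditioned blocks consume.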
## Key Components

- `unet.py`: Wavelet-integrated U-Net with configurable encoder/decoder depths
- `wavelets.py`: Efficient DWT/IWT implementations for multi-scale feature extraction
- `attentions.py`: Multi-head self-attention for capturing long-range dependencies
- `blocks.py`: Residual blocks with time conditioning and group normalization
- `docwavediff_deblurring.py` & `docwavediff_inpanting.py`: Complete models combining the Early Predictor and Denoiser Refiner
- `ema.py`: Exponential Moving Average for stable training
## Citation

If you find this work useful in your research, please consider citing:

```bibtex
@inproceedings{docwavediff2025,
  title={DocWaveDiff: A Predict-and-Refine approach for Document Image Enhancement with Wavelet U-Nets and Diffusion models},
  author={Marulli, Matteo and Bertini, Marco},
  booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
  year={2025}
}
```

## Acknowledgments

This work was conducted at the Media Integration and Communication Center (MICC) and the Department of Information Engineering (DINFO), University of Florence.
coming soon!
## Contact

For questions or issues, please open an issue on GitHub or contact the authors: matteomarulli@unifi.it, marco.bertini@unifi.it.