DocWaveDiff: A Predict-and-Refine approach for Document Image Enhancement with Wavelet U-Nets and Diffusion models


Overview

DocWaveDiff is a novel document restoration method based on a predict-and-refine diffusion framework incorporating wavelet U-Nets. OCR and document layout analysis algorithms are essential components of AI-based document systems, yet they are typically trained on clean, degradation-free images. When applied to degraded documents such as blurred scans or pages spoiled by handwritten text, their performance drops significantly.

Our approach addresses this issue through a two-stage restoration process:

  1. Early Predictor: Generates an initial restoration from a degraded image patch and optionally its prior features
  2. Denoiser Refiner: Estimates the residual image to refine the initial prediction

The combination of these outputs yields the final restored result.
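The two stages above can be sketched as follows; the function names and toy numbers here are purely illustrative, not the repository's API:

```python
import numpy as np

def predict_and_refine(degraded, early_predictor, denoiser_refiner):
    # Stage 1: coarse restoration from the degraded patch
    initial = early_predictor(degraded)
    # Stage 2: residual estimated by the (diffusion-based) refiner
    residual = denoiser_refiner(degraded, initial)
    # Final result: initial prediction plus refinement residual
    return initial + residual

# Toy stand-ins for the two networks, for illustration only.
degraded = np.full((4, 4), 0.5)
early = lambda x: x + 0.3                         # pretend deblurring
refine = lambda x, init: np.full_like(init, 0.1)  # small correction
restored = predict_and_refine(degraded, early, refine)
```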

Demo: animated restoration example (repo_images/gif.gif)

Key Features

  • 🌊 Wavelet Transform Integration: Incorporates Discrete Wavelet Transform (DWT) and Inverse Wavelet Transform (IWT) into U-Net architecture
  • 🔄 Predict-and-Refine Framework: Two-stage approach combining early prediction with diffusion-based refinement
  • 📄 Multiple Degradation Types: Handles deblurring, handwriting removal, and various document degradation scenarios

Installation

1. Clone the Repository

git clone https://github.com/miccunifi/DocWaveDiff.git
cd DocWaveDiff

2. Create Conda Environment

Create a new conda environment named DocWaveDiff with Python 3.12:

conda create -n DocWaveDiff python=3.12
conda activate DocWaveDiff

3. Install Dependencies

Install the required packages:

pip install -r requirements.txt

Project Structure

DocWaveDiff/
├── src/                              # Source code
│   ├── unet.py                       # Wavelet U-Net architecture
│   ├── wavelets.py                   # DWT/IWT implementations (Haar wavelet)
│   ├── attentions.py                 # Attention block mechanisms
│   ├── blocks.py                     # Building blocks (DownBlock, UpBlock, MiddleBlock)
│   ├── swish.py                      # Swish activation function
│   ├── docwavediff_deblurring.py     # DocWaveDiff model for deblurring task
│   ├── docwavediff_inpanting.py      # DocWaveDiff model for inpainting/removal task
│   ├── trainer_deblurring.py         # Training logic for deblurring
│   ├── trainer_inpainting.py         # Training logic for inpainting
│   ├── train.py                      # Main training script
│   ├── ema.py                        # Exponential Moving Average for model weights
│   ├── sobel.py                      # Sobel edge detection utilities
│   ├── config.py                     # Configuration parsing
│   └── utils.py                      # Utility functions
├── confs/                            # Configuration files
│   ├── conf_document_deblurring.yml  # Config for document deblurring task
│   └── conf_document_inpainting.yml  # Config for handwriting removal task
├── data/                             # Data loading utilities
│   ├── docdata.py                    # Dataset loader for deblurring
│   └── docdata_inpainting.py         # Dataset loader for inpainting with masks
├── datasets/                         # Place your datasets here
│   ├── deblurring/
│   │   ├── train/
│   │   │   ├── degraded/
│   │   │   └── ground_truth/
│   │   └── test/
│   └── inpainting/
│       ├── train/
│       │   ├── degraded/
│       │   ├── ground_truth/
│       │   └── masks/
│       └── test/
├── checksave/                        # Model checkpoints saved here
│   ├── deblurring/
│   └── inpainting/
├── results/                          # Inference results saved here
│   ├── deblurring/
│   └── inpainting/
├── Training/                         # Training visualizations (on-the-fly restoration examples)
│   ├── deblurring/
│   └── inpainting/
├── schedule/                         # Diffusion scheduling utilities
├── utils/                            # Additional utilities
├── repo_images/                      # Images for README
│   ├── gif.gif                       # Demo visualization
│   ├── logoMICC_white.png            # MICC Lab logo
│   ├── DINFO_bianco.png              # DINFO Department logo
│   └── wacv_logo.png                 # WACV Conference logo
├── main.py                           # Main entry point
├── requirements.txt                  # Python dependencies
└── README.md                         # This file

Usage

Training

Document Deblurring

To train the model for document deblurring:

python main.py --config confs/conf_document_deblurring.yml

Edit confs/conf_document_deblurring.yml to customize:

  • Dataset paths
  • Model hyperparameters
  • Training settings (epochs, batch size, learning rate)
  • Checkpoint save frequency

Handwriting Removal (Inpainting)

To train the model for handwriting removal:

python main.py --config confs/conf_document_inpainting.yml

Edit confs/conf_document_inpainting.yml for task-specific configurations.

During training:

  • Model weights are saved in checksave/
  • Training visualizations (restoration examples on training set) are saved in Training/
  • Training progress and losses are logged to console

Inference

To restore degraded documents with a trained model, set the MODE flag to 0 (test) in the corresponding .yml configuration file.

Restored images will be saved in the results/ directory.

Dataset Preparation

Organize your dataset in the following structure:

For Deblurring:

datasets/deblurring/
├── train/
│   ├── degraded/          # Blurred document images
│   └── ground_truth/      # Clean ground truth images
└── test/
    ├── degraded/
    └── ground_truth/

For Handwriting Removal:

datasets/inpainting/
├── train/
│   ├── degraded/          # Images with handwriting
│   ├── ground_truth/      # Clean images without handwriting
│   └── masks/             # Binary masks indicating handwriting regions
└── test/
    ├── degraded/
    ├── ground_truth/
    └── masks/

Ensure that corresponding images have the same filename across folders.
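As a sanity check for this requirement, a small script can verify the pairing; the `check_pairing` helper below is hypothetical, not part of the repository:

```python
from pathlib import Path

def check_pairing(split_dir):
    """Verify that degraded/, ground_truth/ (and masks/, if present)
    under one split directory contain matching filenames."""
    split_dir = Path(split_dir)
    subdirs = [split_dir / name for name in ("degraded", "ground_truth", "masks")]
    subdirs = [d for d in subdirs if d.is_dir()]
    name_sets = [{p.name for p in d.iterdir()} for d in subdirs]
    for d, names in zip(subdirs[1:], name_sets[1:]):
        mismatch = name_sets[0] ^ names  # symmetric difference of filename sets
        if mismatch:
            raise ValueError(f"filename mismatch in {d.name}: {sorted(mismatch)[:5]}")
    return len(name_sets[0])
```

Run it on each split, e.g. `check_pairing('datasets/inpainting/train')`.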

Architecture

DocWaveDiff employs a Wavelet U-Net architecture that integrates wavelet transforms at multiple scales:

  • Encoder: Progressively downsamples features using the DWT, capturing multi-scale information in the frequency domain
  • Middle Block: Processes the bottleneck features with attention mechanisms
  • Decoder: Reconstructs the image using the IWT and skip connections from the encoder

The predict-and-refine framework consists of:

  1. Early Predictor U-Net (is_noise=False): Generates initial restoration without time conditioning
  2. Diffusion Refiner U-Net (is_noise=True): Refines prediction through iterative denoising with time embeddings

Wavelet Transforms

The model uses Haar wavelet decomposition:

  • DWT (Discrete Wavelet Transform): Decomposes image into:

    • LL (Low-Low): Approximation coefficients
    • LH (Low-High): Horizontal detail
    • HL (High-Low): Vertical detail
    • HH (High-High): Diagonal detail
  • IWT (Inverse Wavelet Transform): Reconstructs image from wavelet subbands
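A minimal, self-contained sketch of one level of the 2-D Haar decomposition described above (an orthonormal variant for illustration, not the repository's wavelets.py implementation):

```python
import numpy as np

def haar_dwt(x):
    """One level of the 2-D Haar DWT: split into four subbands."""
    a = x[0::2, 0::2]; b = x[0::2, 1::2]
    c = x[1::2, 0::2]; d = x[1::2, 1::2]
    ll = (a + b + c + d) / 2   # approximation coefficients
    lh = (a + b - c - d) / 2   # horizontal detail
    hl = (a - b + c - d) / 2   # vertical detail
    hh = (a - b - c + d) / 2   # diagonal detail
    return ll, lh, hl, hh

def haar_iwt(ll, lh, hl, hh):
    """Inverse transform: exactly reconstructs the input of haar_dwt."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=ll.dtype)
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll + lh - hl - hh) / 2
    x[1::2, 0::2] = (ll - lh + hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x
```

Because the transform is orthonormal, applying `haar_iwt` to the output of `haar_dwt` recovers the input exactly.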

Time Embedding

For the diffusion refiner, sinusoidal positional encodings are used to create time embeddings that condition the denoising process at each diffusion timestep.
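A common Transformer-style sketch of such an embedding; the function name and default `max_period` are assumptions, not the repository's exact code:

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of a diffusion timestep t into `dim` channels:
    half sines, half cosines, over geometrically spaced frequencies."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```

The resulting vector is typically passed through a small MLP before conditioning each residual block.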

Model Components

  • unet.py: Wavelet-integrated U-Net with configurable encoder/decoder depths
  • wavelets.py: Efficient DWT/IWT implementations for multi-scale feature extraction
  • attentions.py: Multi-head self-attention for capturing long-range dependencies
  • blocks.py: Residual blocks with time conditioning and group normalization
  • docwavediff_deblurring.py & docwavediff_inpanting.py: Complete models combining Early Predictor and Denoiser Refiner
  • ema.py: Exponential Moving Average for stable training
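The EMA used for stable training can be sketched generically as follows; the decay value and list-based parameters are illustrative, not the repository's ema.py:

```python
def ema_update(shadow, params, decay=0.999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * params,
    applied elementwise to a list of parameter values."""
    return [decay * s + (1 - decay) * p for s, p in zip(shadow, params)]
```

The shadow weights change slowly, smoothing out noisy per-step updates; they are typically the weights used at inference time.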

Citation

If you find this work useful in your research, please consider citing:

@inproceedings{docwavediff2025,
  title={DocWaveDiff: A Predict-and-Refine approach for Document Image Enhancement with Wavelet U-Nets and Diffusion models},
  author={Marulli, Matteo and Bertini, Marco},
  booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
  year={2025}
}

Acknowledgments

This work was conducted at the Media Integration and Communication Center (MICC) and the Department of Information Engineering (DINFO), University of Florence.

License

Coming soon.

Contact

For questions or issues, please open an issue on GitHub or contact the authors: matteomarulli@unifi.it, marco.bertini@unifi.it.
