Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 0 additions & 43 deletions .pre-commit-config.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions CONTRIBUTING.md

This file was deleted.

46 changes: 27 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
# Generative SSL

This is the PyTorch implemention of our paper **"Can Generative Models Improve Self-Supervised Representation Learning?"** submitted to ECCV 2024 for reproducing the experiments.
This repository contains the PyTorch implementation of **"Can Generative Models Improve Self-Supervised Representation Learning?"** accepted to AAAI 2025.

## Abstract

Self-supervised learning (SSL) holds significant promise in leveraging unlabeled data for learning robust visual representations. However, the limited diversity and quality of existing augmentation techniques constrain SSL performance. We introduce a novel framework that incorporates generative models to produce semantically consistent and diverse augmentations conditioned on source images. This approach enriches SSL training, improving downstream task performance by up to 10\% in Top-1 accuracy across various techniques.

![avatar](./images/GenSSL_last-main.png)
Our augmentation pipeline utilizes generative models, i.e., Stable Diffusion or ICGAN, conditioned on the source image representation, accompanied by the standard SSL augmentations. The components inside the Generative Augmentation module, i.e. the pretrained SSL encoder and the generative model remain frozen throughout the SSL training process.


## Requirements

To create the virtual environment for running the experiments, you need to run:
We used solo-learn library for the implementation of SSL method. You can find the library in this [LINK](https://github.com/vturrisi/solo-learn).

`pip install -r requirements.txt`
To create the virtual environment for running the experiments please first:

**Note:**
**You always need to set the proper path to the virtual environment, the dataset and the model in each SLURM file before submitting the job. Here are the options for the datasets and models that we used in our experiments:**
`cd solo-learn`

Then install requirements based on solo-learn library documentation [here](https://github.com/vturrisi/solo-learn?tab=readme-ov-file#installation).

- **Datasets:** ImageNet, iNaturalist2018, Food101, Places365, CIFAR10/100
- **Models:** Baseline (SimSiam model trained on ImageNet), SimSiam model trained with ICGAN augmentations, SimSiam model trained with Stable Diffusion augmentations

## Data Generation
**Note:**
**You always need to set the proper path to the virtual environment and path to save generated data in generation scripts.**

To generate augmentations with ICGAN run:

Expand All @@ -24,24 +33,23 @@ To generate augmentations with Stable Diffusion run:

`sbatch GenerativeSSL/scripts/generation_scripts/gen_img_stablediff.slrm`

## Training

To train the SimSiam method on the ImageNet, run:

`sbatch GenerativeSSL/scripts/train_scrpits/train_simsiam_singlenode.slrm`
## Training and Evaluation
**Note:**
**You always need to set the proper path to the virtual environment in solo-learn slrm files. We pretrained our models on train split of Imagenet. Here is the model and dataset choices for the evaluation that we used in our experiments:**

In this file, there is a `use_synthetic_data` flag that you can use to train the model with augmentations. You just need to specify the path to synthetic data. (Either ICGAN or Stable Diffusion augmentations) By default, the `use_synthetic_data` flag has been passed in the SLURM file.
- **Datasets:** ImageNet, iNaturalist2018, Food101, Places365, CIFAR10/100
- **Models:** SimCLR (Baseline, ICGAN, Stablediff), SimSiam (Baseline, ICGAN, Stablediff), MoCo (Baseline, ICGAN, Stablediff), BYOL (Baseline, ICGAN, Stablediff), Barlow Twins (Baseline, ICGAN, Stablediff)

## Evaluation
### Training

For downstream tasks, there are all evaluation scripts in this `GenerativeSSL/scripts/eval_scripts` folder. In each dataset folder in `eval_scripts` there are three SLURM files. (baseline model, model trained with ICGAN aug, model trained with stablediff aug)
Configs for training are in the `solo-learn/scripts/pretrain` folder. You can find the config files for each model and dataset in the respective folders. You need to set **path for the dataset** and **dir to save model** in each respective config file before submitting the job. By choosing the desired config you can train the methods on the ImageNet, run:

Similarly for evaluation, you just need to submit the slurm file related to the dataset you want. Again, you need to specify the path to the virtual environment, the dataset and the related checkpoint in each SLURM file. For example, command below run the experiment of evaluating model trained with stable diffusion augmentations on Food101:
`sbatch scripts/solo_learn/train_solo_learn.slrm`

`sbatch GenerativeSSL/scripts/eval_scripts/food101/stablediff.slrm`

## Pretrained Models

We also provide the checkpoints for all the trained models here in the [LINK](https://drive.google.com/drive/folders/1xPIbf1cOPqzIzuZ185GjAprA8XmQ0Tvu)
### Evaluation

Configs for evaluation are in the `solo-learn/scripts/linear` folder. You can find the config files for each model and dataset in the respective folders. You need to set **path for the dataset**, **dir to save model** and **path to pretrained feature extractor** in each respective config file before submitting the job. By choosing the desired config you can train the methods on the ImageNet, run:

`sbatch scripts/solo_learn/eval_solo_learn.slrm`
Binary file added images/GenSSL_last-main.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
56 changes: 0 additions & 56 deletions pyproject.toml

This file was deleted.

10 changes: 0 additions & 10 deletions requirements.txt

This file was deleted.

47 changes: 0 additions & 47 deletions scripts/eval_scripts/CIFAR10/baseline.slrm

This file was deleted.

47 changes: 0 additions & 47 deletions scripts/eval_scripts/CIFAR10/icgan.slrm

This file was deleted.

47 changes: 0 additions & 47 deletions scripts/eval_scripts/CIFAR10/stablediff.slrm

This file was deleted.

Loading
Loading