Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0eeb86e
Add simsiam.
afkanpour Feb 27, 2024
260c98b
Adding simsiam files.
afkanpour Feb 27, 2024
0f19437
Add SimSiam eval script.
afkanpour Feb 28, 2024
c724279
Add slurm files for eval runs.
afkanpour Feb 29, 2024
19dbce5
Added tqdm.
afkanpour Feb 29, 2024
a2dfc9f
Modified eval scripts to use epoch 50 checkpoint.
afkanpour Feb 29, 2024
9c17abb
Add Lars to linear evaluation
fereshtehforghani Mar 1, 2024
eb8964b
Added a linear eval scrip that prints out the original logs
afkanpour Mar 1, 2024
a16c5a8
changes to multiple scripts
afkanpour Mar 4, 2024
85fd02a
resolving merge conflicts.
afkanpour Mar 4, 2024
18d8798
add adil multi node script
sanaAyrml Mar 4, 2024
984cc36
add add multi gpu to original loging
sanaAyrml Mar 4, 2024
653c80c
add eval scripts
sanaAyrml Mar 4, 2024
4762d19
update training
sanaAyrml Mar 4, 2024
16e683d
changed the train slurm script
afkanpour Mar 5, 2024
71b256b
Add food101 and places365 to linear evaluation script
fereshtehforghani Mar 5, 2024
c2fcc80
add original eval
sanaAyrml Mar 6, 2024
2bc88d7
Merge branch 'simsiam' of github.com:VectorInstitute/GenerativeSSL in…
fereshtehforghani Mar 6, 2024
6c3f957
Add food101 and cifar10 to linear eval code
fereshtehforghani Mar 6, 2024
0d176fb
Add places365 dataset to linear eval.
fereshtehforghani Mar 6, 2024
5ca2ff4
minor changes.
afkanpour Mar 6, 2024
0f2dedd
add checkpoint_dir flag to eval script.
afkanpour Mar 6, 2024
582bc4d
Fix error.
afkanpour Mar 6, 2024
2fccb3a
add inaturalist
sanaAyrml Mar 6, 2024
0bc3f3f
fix inat
sanaAyrml Mar 6, 2024
a4a8ef9
fix inaturalist script
sanaAyrml Mar 6, 2024
06d77dd
update class head
sanaAyrml Mar 6, 2024
6d3584b
add cifar100
sanaAyrml Mar 7, 2024
c98b3e6
minor changes
afkanpour Mar 7, 2024
ceb3bcc
merge
afkanpour Mar 7, 2024
8c34ca1
add icgan ablation
sanaAyrml Mar 11, 2024
27311ae
add icgan ablation
sanaAyrml Mar 11, 2024
35a875b
debug
sanaAyrml Mar 11, 2024
e2830b9
update num classes
sanaAyrml Mar 11, 2024
95addac
update code
sanaAyrml Mar 11, 2024
f044643
update linear eval code
sanaAyrml Mar 11, 2024
ee0a674
update icgan code
sanaAyrml Mar 11, 2024
6efa7ef
correct code
sanaAyrml Mar 11, 2024
dbf6c4f
add sanity check
sanaAyrml Mar 11, 2024
65b44c7
add linear eval files with clip
vahid0001 Mar 13, 2024
c2a4a9c
add slurm scripts for experiments
vahid0001 Mar 14, 2024
f88558f
Update gitignore
sanaAyrml Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ dmypy.json

# pycharm
.idea/

# Trained models
trained_models/
39 changes: 29 additions & 10 deletions SimCLR/data_aug/imagenet_synthetic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import random

import torch
from PIL import Image
from torchvision import datasets, transforms

Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
imagenet_synthetic_root,
index_min=0,
index_max=9,
generative_augmentation_prob=None,
load_one_real_image=False,
split="train",
):
Expand All @@ -48,6 +50,7 @@ def __init__(
self.imagenet_synthetic_root = imagenet_synthetic_root
self.index_min = index_min
self.index_max = index_max
self.generative_augmentation_prob = generative_augmentation_prob
self.load_one_real_image = load_one_real_image
self.synthetic_transforms = _get_simclr_transforms(size=224)
self.real_transforms = _get_simclr_transforms(size=224, random_crop=True)
Expand All @@ -62,21 +65,37 @@ def _synthetic_image(filename):
filename_parent_dir = filename.split("/")[-2]
image_path = os.path.join(
self.imagenet_synthetic_root,
# self.split,
self.split,
filename_parent_dir,
filename_and_extension.split(".")[0] + f"_{rand_int}.JPEG",
)
return Image.open(image_path).convert("RGB")

if self.load_one_real_image:
image1 = self.loader(os.path.join(self.root, imagenet_filename))
image1 = self.real_transforms(image1)
else:
image1 = _synthetic_image(imagenet_filename)
image1 = self.synthetic_transforms(image1)
if self.generative_augmentation_prob is not None:
if torch.rand(1) < self.generative_augmentation_prob:
# Generate a synthetic image.
image1 = _synthetic_image(imagenet_filename)
image1 = self.synthetic_transforms(image1)
else:
image1 = self.loader(os.path.join(self.root, imagenet_filename))
image1 = self.real_transforms(image1)

# image2 is always synthetic.
image2 = _synthetic_image(imagenet_filename)
image2 = self.synthetic_transforms(image2)
if torch.rand(1) < self.generative_augmentation_prob:
# Generate another synthetic image.
image2 = _synthetic_image(imagenet_filename)
image2 = self.synthetic_transforms(image2)
else:
image2 = self.loader(os.path.join(self.root, imagenet_filename))
image2 = self.real_transforms(image2)
else:
if self.load_one_real_image:
image1 = self.loader(os.path.join(self.root, imagenet_filename))
image1 = self.real_transforms(image1)
else:
image1 = _synthetic_image(imagenet_filename)
image1 = self.synthetic_transforms(image1)
# image2 is always synthetic.
image2 = _synthetic_image(imagenet_filename)
image2 = self.synthetic_transforms(image2)

return {"view1": image1, "view2": image2}, label
20 changes: 17 additions & 3 deletions SimCLR/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ def __init__(self, *args, **kwargs):
self.device_id,
)
self.checkpoint_dir = self.args.checkpoint_dir
self.start_epoch = 0

if self.args.last_checkpoint:
checkpoint = torch.load(self.args.last_checkpoint)
self.model.load_state_dict(checkpoint["state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer"])
# Start from the next epoch.
self.start_epoch = checkpoint["epoch"] + 1
print(
f"Checkpoint loaded. Resuming training from epoch: {self.start_epoch}"
)

def train(self, train_loader):
scaler = GradScaler(enabled=self.args.fp16_precision)
Expand All @@ -32,9 +43,12 @@ def train(self, train_loader):
print(f"Log dir: {self.writer.log_dir}")

n_iter = 0
print(f"Start SimCLR training for {self.args.epochs} epochs.")
print(
f"Start SimCLR training for {self.args.epochs} epochs starting from {self.start_epoch}."
)

for epoch_counter in tqdm(range(self.args.epochs), desc="Training Progress"):
train_range = range(self.start_epoch, self.args.epochs)
for epoch_counter in tqdm(train_range, desc="Training Progress"):
if dist_utils.is_dist_avail_and_initialized():
train_loader.sampler.set_epoch(epoch_counter)
for images, _ in tqdm(train_loader):
Expand Down Expand Up @@ -77,7 +91,7 @@ def train(self, train_loader):
checkpoint_file = os.path.join(self.checkpoint_dir, checkpoint_name)
save_checkpoint(
{
"epoch": self.args.epochs,
"epoch": epoch_counter,
"arch": self.args.arch,
"state_dict": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
Expand Down
39 changes: 39 additions & 0 deletions eval_scripts/food101/eval_food101_original_simsiam_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Linear evaluation of the original (upstream) SimSiam checkpoint on Food-101.
# Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_original" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="food101" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/_original_simsiam/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
40 changes: 40 additions & 0 deletions eval_scripts/food101/eval_food101_simsiam_baseline_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Linear evaluation of the SimSiam baseline (epoch-99 checkpoint) on Food-101.
# Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_baseline" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="food101" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
39 changes: 39 additions & 0 deletions eval_scripts/food101/eval_food101_simsiam_icgan_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Linear evaluation of the SimSiam + ICGAN checkpoint (epoch 99) on Food-101.
# Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_icgan" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="food101" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_2024-02-29-18-40/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
39 changes: 39 additions & 0 deletions eval_scripts/food101/eval_food101_simsiam_stablediff_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Linear evaluation of the SimSiam + Stable Diffusion checkpoint (epoch 99) on
# Food-101. Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_stable_diff" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="food101" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
39 changes: 39 additions & 0 deletions eval_scripts/places365/eval_places365_original_simsiam_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Linear evaluation of the original (upstream) SimSiam checkpoint on Places365.
# Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_original" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="places365" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/_original_simsiam/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
39 changes: 39 additions & 0 deletions eval_scripts/places365/eval_places365_simsiam_baseline_100.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Linear evaluation of the SimSiam baseline (epoch-99 checkpoint) on Places365.
# Slurm launches <nodes * ntasks-per-node> tasks (one per GPU); torch
# distributed rendezvous uses MASTER_ADDR/MASTER_PORT exported below.

#SBATCH --job-name="simsiam_eval"
#SBATCH --partition=t4v2
#SBATCH --account=deadline
#SBATCH --qos=deadline
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --time=36:00:00
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=8G
#SBATCH --output=slurm-%j.out


# Any extra command-line arguments are forwarded verbatim to the eval script.
PY_ARGS="${@:1}"

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

# "srun" executes the script <ntasks-per-node * nodes> times
srun python simsiam/linear_eval_downstream_datasets.py \
    --data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \
    --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_baseline" \
    --arch="resnet50" \
    --distributed_mode \
    --batch-size=1024 \
    --lars \
    --dataset_name="places365" \
    --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0099.pth.tar" \
    ${PY_ARGS}
Loading