Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
acb48e2
Distributed rcdm
sanaAyrml Jan 24, 2024
cdec0c9
Add rcdm model
sanaAyrml Jan 25, 2024
23cdf80
Add checkpointing
sanaAyrml Jan 26, 2024
ae42737
update
sanaAyrml Jan 26, 2024
73c580b
update
sanaAyrml Jan 26, 2024
25936dc
update config
sanaAyrml Jan 26, 2024
2d9d197
edit logging
sanaAyrml Jan 26, 2024
422a10d
server
sanaAyrml Jan 30, 2024
60d3dc3
add checkpointing
sanaAyrml Jan 30, 2024
9ae6ee8
Add eval files
sanaAyrml Feb 1, 2024
642ad8a
add slrm file
sanaAyrml Feb 1, 2024
ea7cc42
update eval file
sanaAyrml Feb 1, 2024
a9cbba1
update eval
sanaAyrml Feb 1, 2024
90ea717
check eval
sanaAyrml Feb 1, 2024
1cba4c8
check
sanaAyrml Feb 1, 2024
6a34a85
check state_dicts
sanaAyrml Feb 1, 2024
86dcdf7
check
sanaAyrml Feb 1, 2024
4e683a4
edit eval classes
sanaAyrml Feb 1, 2024
ecfab47
check
sanaAyrml Feb 1, 2024
20795af
update slrm
sanaAyrml Feb 1, 2024
38a86e6
Update eval
sanaAyrml Feb 1, 2024
d85de17
edit
sanaAyrml Feb 1, 2024
e359965
edit slrm
sanaAyrml Feb 1, 2024
c78abd0
correct sample slrm
sanaAyrml Feb 1, 2024
a1fa4ea
fix multi gpu
sanaAyrml Feb 1, 2024
fc054cb
Merge branch 'main' into add_simclr_eval
sanaAyrml Feb 12, 2024
f96c121
Delete pytest
sanaAyrml Feb 12, 2024
73eb9e4
update eval
sanaAyrml Feb 13, 2024
e67b598
clean code
sanaAyrml Feb 13, 2024
56688ca
clean code
sanaAyrml Feb 13, 2024
084d05c
clean code
sanaAyrml Feb 13, 2024
1b9ccf3
clean code
sanaAyrml Feb 13, 2024
c707014
clean code
sanaAyrml Feb 13, 2024
b36c694
debug rcdm error
sanaAyrml Feb 13, 2024
cf9713d
edit
sanaAyrml Feb 13, 2024
f8e8eb1
edit
sanaAyrml Feb 13, 2024
5755923
delete normalize
sanaAyrml Feb 13, 2024
25d250f
edit
sanaAyrml Feb 13, 2024
5754ae1
delete print
sanaAyrml Feb 13, 2024
509530f
Merge branch 'main' into add_simclr_eval
sanaAyrml Feb 16, 2024
0642905
update evaluation
sanaAyrml Feb 20, 2024
245eb54
update formatting
sanaAyrml Feb 20, 2024
33c0c93
update logging
sanaAyrml Feb 23, 2024
31409dd
Update bash files
sanaAyrml Feb 23, 2024
eb9a4b6
Update augmentation and saving file
sanaAyrml Feb 23, 2024
288f749
update evaluation
sanaAyrml Feb 23, 2024
3b43d4e
Update bash file
sanaAyrml Feb 23, 2024
e64fe5c
edit eval
sanaAyrml Feb 23, 2024
ef6a214
check loading
sanaAyrml Feb 23, 2024
cef4572
debug eval
sanaAyrml Feb 23, 2024
883d9c0
update
sanaAyrml Feb 23, 2024
c900b9f
check evaluation
sanaAyrml Feb 23, 2024
314633c
Clean the code
sanaAyrml Feb 23, 2024
ed78e17
update
sanaAyrml Feb 23, 2024
a392667
try catch the file exist error
sanaAyrml Feb 23, 2024
6c5cf20
update
sanaAyrml Feb 23, 2024
313b705
update logging part
sanaAyrml Feb 23, 2024
67b5a9f
update slrm scripts
sanaAyrml Feb 25, 2024
1e2ba81
update resnet pretrained
sanaAyrml Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,6 @@ repos:
- id: nbqa-ruff
args: [--fix]

- repo: local
hooks:
- id: pytest
name: pytest
entry: python3 -m pytest
language: system
pass_filenames: false
always_run: true

exclude: |
(?x)(
^rcdm/|
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from torchvision import datasets, transforms

from SimCLR.data_aug.gaussian_blur import GaussianBlur
from SimCLR.data_aug.icgan_aug import ICGANInference
from SimCLR.data_aug.icgan_config import get_icgan_config
from SimCLR.data_aug.rcdm_aug import RCDMInference
from SimCLR.data_aug.rcdm_config import get_config
from SimCLR.data_aug.view_generator import ContrastiveLearningViewGenerator
from SimCLR.datasets.data_aug.gaussian_blur import GaussianBlur
from SimCLR.datasets.data_aug.icgan_aug import ICGANInference
from SimCLR.datasets.data_aug.icgan_config import get_icgan_config
from SimCLR.datasets.data_aug.rcdm_aug import RCDMInference
from SimCLR.datasets.data_aug.rcdm_config import get_config
from SimCLR.datasets.view_generator import ContrastiveLearningViewGenerator
from SimCLR.exceptions.exceptions import InvalidDatasetSelection


Expand Down Expand Up @@ -114,3 +114,4 @@ def get_dataset(
raise InvalidDatasetSelection()
else:
return dataset_fn()

44 changes: 44 additions & 0 deletions SimCLR/datasets/data_aug/center_crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch.nn.functional as F
import torchvision
import torch

class CostumeCenterCrop(torch.nn.Module):
    """Center-crop an image either to a fixed size or to a target aspect ratio.

    When ``size`` is given, this behaves like a plain center crop to that size.
    When ``size`` is None, the crop size is derived from ``ratio`` (a
    ``"W:H"``-style string) so that the crop reaches the edge of the image.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        target = self.size if self.size is not None else self._ratio_size(img)
        return torchvision.transforms.functional.center_crop(img, target)

    def _ratio_size(self, img):
        """Compute the largest (h, w) crop matching ``self.ratio`` for ``img``."""
        # Tensors are (..., H, W); PIL images report (W, H) via .size.
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        num, den = self.ratio.split(":")
        aspect = float(num) / float(den)
        # Candidate sizes that match the ratio while touching an image edge.
        cropped_w = int(height * aspect)
        cropped_h = int(width / aspect)
        if width >= height:
            return (cropped_h, width) if cropped_h <= height else (height, cropped_w)
        return (height, cropped_w) if cropped_w <= width else (cropped_h, width)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, config, device_id):

# Load SSL model
self.ssl_model = (
get_model(self.config.type_model, self.config.use_head)
get_model(self.config.type_model, self.config.use_head, self.config.pretrained_models_dir)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this change? Why wasn't this passed before?

.cuda(self.device_id)
.eval()
)
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, config, device_id):

if self.config.model_path == "":
trained_model = get_dict_rcdm_model(
self.config.type_model, self.config.use_head
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

self.config.type_model, self.config.use_head, self.config.pretrained_models_dir
)
else:
trained_model = torch.load(self.config.model_path, map_location="cpu")
Expand All @@ -63,7 +63,6 @@ def preprocess_input_image(self, input_image, size=224):
data_utils.CenterCropLongEdge(),
transforms.Resize((size, size)),
transforms.ToTensor(),
transforms.Normalize(self.config.norm_mean, self.config.norm_std),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this? Is it causing an error?

]
)
tensor_image = transform_list(input_image)
Expand All @@ -89,9 +88,8 @@ def __call__(self, img):
if not self.config.use_ddim
else self.diffusion.ddim_sample_loop
)

img = img.unsqueeze(0).repeat(1, 1, 1, 1)
img = self.preprocess_input_image(img).cuda(self.device_id)
img = img.repeat(1, 1, 1, 1)
model_kwargs = {}

with torch.no_grad():
Expand All @@ -104,5 +102,4 @@ def __call__(self, img):
model_kwargs=model_kwargs,
)

print("Sampling completed!")
return sample.squeeze(0)
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import ml_collections


def get_config():
    """Build and return the default RCDM sampling/model configuration.

    Returns:
        ml_collections.ConfigDict: Hyper-parameters for diffusion sampling and
        for the SSL model whose representations condition the generation.
    """
    config = ml_collections.ConfigDict()
    config.image_size = 128  # The size of the images to generate.
    config.class_cond = False  # If true, use class conditional generation.
    config.type_model = "simclr"  # Type of model to use (e.g., simclr, dino).
    config.use_head = False  # If true, use the projector/head for SSL representation.
    config.model_path = ""  # Replace with the path to your model if you have one.
    config.use_ddim = False  # If true, use DDIM sampler.
    config.no_shared = True  # If false, enables squeeze and excitation.
    config.clip_denoised = True  # If true, clip denoised images.
    config.attention_resolutions = "32,16,8"  # Resolutions to use for attention layers.
    config.diffusion_steps = 100  # Number of diffusion steps.
    config.learn_sigma = True  # If true, learn the noise level.
    config.noise_schedule = "linear"  # Type of noise schedule (e.g., linear).
    config.num_channels = 256  # Number of channels in the model.
    config.num_heads = 4  # Number of attention heads.
    config.num_res_blocks = 2  # Number of residual blocks.
    config.resblock_updown = True  # If true, use up/down sampling in resblocks.
    config.use_fp16 = False  # If true, use 16-bit floating point precision.
    config.use_scale_shift_norm = True  # If true, use scale-shift normalization.
    config.ssl_image_size = 224  # Size of the input images for the SSL model.
    config.ssl_image_channels = (
        3  # Number of channels of the input images for the SSL model.
    )

    return config
import ml_collections


def get_config():
    """Build and return the default RCDM sampling/model configuration.

    Returns:
        ml_collections.ConfigDict: Hyper-parameters for diffusion sampling and
        for the SSL model whose representations condition the generation.
    """
    settings = {
        "image_size": 128,  # The size of the images to generate.
        "class_cond": False,  # If true, use class conditional generation.
        "pretrained_models_dir": "/ssd003/projects/aieng/genssl",  # Path to the directory containing the model.
        "type_model": "simclr",  # Type of model to use (e.g., simclr, dino).
        "use_head": False,  # If true, use the projector/head for SSL representation.
        "model_path": "",  # Replace with the path to your model if you have one.
        "use_ddim": True,  # If true, use DDIM sampler.
        "no_shared": True,  # If false, enables squeeze and excitation.
        "clip_denoised": True,  # If true, clip denoised images.
        "attention_resolutions": "32,16,8",  # Resolutions to use for attention layers.
        "diffusion_steps": 100,  # Number of diffusion steps.
        "learn_sigma": True,  # If true, learn the noise level.
        "noise_schedule": "linear",  # Type of noise schedule (e.g., linear).
        "num_channels": 256,  # Number of channels in the model.
        "num_heads": 4,  # Number of attention heads.
        "num_res_blocks": 2,  # Number of residual blocks.
        "resblock_updown": True,  # If true, use up/down sampling in resblocks.
        "use_fp16": False,  # If true, use 16-bit floating point precision.
        "use_scale_shift_norm": True,  # If true, use scale-shift normalization.
        "ssl_image_size": 224,  # Size of the input images for the SSL model.
        "ssl_image_channels": 3,  # Number of channels of the input images for the SSL model.
        "timestep_respacing": "ddim2",  # Type of timestep respacing (e.g., ddim25).
    }
    return ml_collections.ConfigDict(settings)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update config to available ddim.

57 changes: 57 additions & 0 deletions SimCLR/datasets/supervised_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from torchvision import datasets, transforms
from torchvision.transforms import transforms
Comment on lines +1 to +2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines both import transforms which could cause confusion. Is one of them unnecessary and can be removed?


from SimCLR.exceptions.exceptions import InvalidDatasetSelection
from SimCLR.datasets.data_aug.center_crop import CostumeCenterCrop
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we change this? (see the comment on the module)


class SupervisedDataset:
    """Factory for torchvision datasets wrapped with simple supervised transforms.

    Attributes:
        root_folder (str): Directory where datasets are stored / downloaded.
    """

    # Per-dataset input resolution, kept in one place instead of inline
    # magic numbers (reviewer request: make these constants configurable).
    IMAGE_SIZES = {"imagenet": 224, "cifar10": 32, "stl10": 96}

    def __init__(self, root_folder):
        self.root_folder = root_folder

    @staticmethod
    def get_transform(size):
        """Return a set of simple transformations for supervised learning.

        Args:
            size (int): Image size.
        """
        transform_list = [
            CostumeCenterCrop(),
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ]
        return transforms.Compose(transform_list)

    def get_dataset(self, name, train=True):
        """Return the requested dataset with the standard transform applied.

        Args:
            name (str): One of "imagenet", "cifar10" or "stl10".
            train (bool): If True use the training split, otherwise the
                evaluation split ("val" for ImageNet, "test" for STL10).

        Raises:
            InvalidDatasetSelection: If ``name`` is not a supported dataset.
        """
        if name == "imagenet":
            # ImageNet is assumed to already exist on disk (no download flag).
            return datasets.ImageNet(
                self.root_folder,
                split="train" if train else "val",
                transform=self.get_transform(self.IMAGE_SIZES[name]),
            )
        elif name == "cifar10":
            return datasets.CIFAR10(
                self.root_folder,
                train=train,
                transform=self.get_transform(self.IMAGE_SIZES[name]),
                download=True,
            )
        elif name == "stl10":
            return datasets.STL10(
                self.root_folder,
                split="train" if train else "test",
                transform=self.get_transform(self.IMAGE_SIZES[name]),
                download=True,
            )
        else:
            raise InvalidDatasetSelection()
File renamed without changes.
64 changes: 64 additions & 0 deletions SimCLR/models/resnet_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torchvision import models

from ..exceptions.exceptions import InvalidBackboneError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use absolute paths.



class PretrainedResNet(nn.Module):
    """ResNet backbone initialized from a SimCLR pretraining checkpoint.

    Loads backbone weights (everything except the final ``fc`` layer) from
    ``pretrained_model_file`` and, when ``linear_eval`` is True, freezes the
    backbone so only the classification head is trainable.
    """

    def __init__(self, base_model, pretrained_model_file, linear_eval=True, num_classes=10):
        super().__init__()

        self.pretrained_model_file = pretrained_model_file

        # Map names to factories so only the requested architecture is built.
        # (Previously both resnet18 and resnet50 were instantiated eagerly,
        # wasting time and memory for the model that was never used.)
        self.resnet_dict = {
            "resnet18": lambda: models.resnet18(pretrained=False, num_classes=num_classes),
            "resnet50": lambda: models.resnet50(pretrained=False, num_classes=num_classes),
        }

        self.backbone = self._get_basemodel(base_model)

        # Load pretrained weights; only the freshly-initialized fc head
        # is allowed to be absent from the checkpoint.
        log = self._load_pretrained()
        assert log.missing_keys == ["fc.weight", "fc.bias"]

        if linear_eval:
            # Freeze all layers but the last fc.
            self._freeze_backbone()
            parameters = [p for p in self.backbone.parameters() if p.requires_grad]
            assert len(parameters) == 2  # fc.weight, fc.bias

    def _load_pretrained(self):
        """Load backbone weights from the checkpoint, stripping the DDP prefix."""
        checkpoint = torch.load(self.pretrained_model_file, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("module.backbone."):
                if not k.startswith("module.backbone.fc"):
                    # Remove the "module.backbone." prefix added by DDP training.
                    state_dict[k[len("module.backbone.") :]] = state_dict[k]
                del state_dict[k]
        # strict=False: the fc head is newly initialized, not in the checkpoint.
        return self.backbone.load_state_dict(state_dict, strict=False)

    def _freeze_backbone(self):
        """Freeze all layers but the last fc."""
        for name, param in self.backbone.named_parameters():
            if name not in ["fc.weight", "fc.bias"]:
                param.requires_grad = False

    def _get_basemodel(self, model_name):
        """Instantiate the requested backbone, raising on unknown names."""
        try:
            return self.resnet_dict[model_name]()
        except KeyError as err:
            # Chain the original KeyError for easier debugging.
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50",
            ) from err

    def forward(self, x):
        return self.backbone(x)
12 changes: 10 additions & 2 deletions SimCLR/simclr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime

import torch
from torch.cuda.amp import GradScaler, autocast
Expand All @@ -18,7 +19,15 @@ def __init__(self, *args, **kwargs):
self.optimizer = kwargs["optimizer"]
self.scheduler = kwargs["scheduler"]
self.device_id = kwargs["device_id"]
self.writer = SummaryWriter()
# Create a directory to save the model checkpoints and logs
now = datetime.now()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added logic to run_simCLR.py that does something like this. So you have to either remove these parts, or remove my changes before merging.

dt_string = now.strftime("%Y_%m_%d_%H_%M")
log_dir = os.path.join(args.model_dir, args.experiment_name,dt_string)
try:
os.makedirs(log_dir)
except FileExistsError:
print(f"Directory {log_dir} made by another worker", flush=True)
self.writer = SummaryWriter(log_dir)
self.criterion = loss.SimCLRContrastiveLoss(self.args.temperature).cuda(
self.device_id
)
Expand Down Expand Up @@ -62,7 +71,6 @@ def train(self, train_loader):
self.scheduler.get_last_lr()[0],
global_step=n_iter,
)

n_iter += 1

# warmup for the first 10 epochs
Expand Down
51 changes: 51 additions & 0 deletions eval_simclr.slrm
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/bin/bash
# Evaluate every SimCLR checkpoint of an experiment with linear evaluation,
# one distributed srun launch per checkpoint file.

#SBATCH --job-name=eval_simclr
#SBATCH --partition=t4v2
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem=100G
#SBATCH --output=logs/simclr/eval_slurm-%N-%j.out
#SBATCH --error=logs/simclr/eval_slurm-%N-%j.err
#SBATCH --qos=m

PY_ARGS=${@:1}

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679

export PYTHONPATH="."
nvidia-smi

pretrained_model_dir="/projects/imagenet_synthetic/train_models"
experiment_name="simclr/2024_02_23_13_02"

# Loop over each checkpoint and pass it to the evaluation script.
# A quoted glob is used instead of parsing `ls` output (and the cd/$OLDPWD
# round-trip), which is fragile with unusual filenames.
for checkpoint in "$pretrained_model_dir/$experiment_name"/checkpoint_epoch_*
do
    file=$(basename "$checkpoint")
    echo "Evaluating: $file"

    # srun execute ntasks-per-node * nodes times
    srun python evaluate_simCLR.py \
        --distributed_mode \
        --batch-size=256 \
        --pretrained_model_dir=$pretrained_model_dir \
        --experiment_name=$experiment_name \
        --pretrained_model_name=$file \
        --linear_evaluation
done
Loading