-
Notifications
You must be signed in to change notification settings - Fork 0
Add simclr eval #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add simclr eval #13
Changes from all commits
acb48e2
cdec0c9
23cdf80
ae42737
73c580b
25936dc
2d9d197
422a10d
60d3dc3
9ae6ee8
642ad8a
ea7cc42
a9cbba1
90ea717
1cba4c8
6a34a85
86dcdf7
4e683a4
ecfab47
20795af
38a86e6
d85de17
e359965
c78abd0
a1fa4ea
fc054cb
f96c121
73eb9e4
e67b598
56688ca
084d05c
1b9ccf3
c707014
b36c694
cf9713d
f8e8eb1
5755923
25d250f
5754ae1
509530f
0642905
245eb54
33c0c93
31409dd
eb9a4b6
288f749
3b43d4e
e64fe5c
ef6a214
cef4572
883d9c0
c900b9f
314633c
ed78e17
a392667
6c5cf20
313b705
67b5a9f
1e2ba81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| import torch.nn.functional as F | ||
| import torchvision | ||
| import torch | ||
|
|
||
| class CostumeCenterCrop(torch.nn.Module): | ||
| def __init__(self, size=None, ratio="1:1"): | ||
| super().__init__() | ||
| self.size = size | ||
| self.ratio = ratio | ||
| def forward(self, img): | ||
| """ | ||
| Args: | ||
| img (PIL Image or Tensor): Image to be cropped. | ||
|
|
||
| Returns: | ||
| PIL Image or Tensor: Cropped image. | ||
| """ | ||
| if self.size is None: | ||
| if isinstance(img, torch.Tensor): | ||
| h, w = img.shape[-2:] | ||
| else: | ||
| w, h = img.size | ||
| ratio = self.ratio.split(":") | ||
| ratio = float(ratio[0]) / float(ratio[1]) | ||
| # Size must match the ratio while cropping to the edge of the image | ||
| ratioed_w = int(h * ratio) | ||
| ratioed_h = int(w / ratio) | ||
| if w>=h: | ||
| if ratioed_h <= h: | ||
| size = (ratioed_h, w) | ||
| else: | ||
| size = (h, ratioed_w) | ||
| else: | ||
| if ratioed_w <= w: | ||
| size = (h, ratioed_w) | ||
| else: | ||
| size = (ratioed_h, w) | ||
| else: | ||
| size = self.size | ||
| return torchvision.transforms.functional.center_crop(img, size) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(size={self.size})" | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,7 @@ def __init__(self, config, device_id): | |
|
|
||
| # Load SSL model | ||
| self.ssl_model = ( | ||
| get_model(self.config.type_model, self.config.use_head) | ||
| get_model(self.config.type_model, self.config.use_head, self.config.pretrained_models_dir) | ||
afkanpour marked this conversation as resolved.
Show resolved
Hide resolved
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain this change? Why wasn't this passed before? |
||
| .cuda(self.device_id) | ||
| .eval() | ||
| ) | ||
|
|
@@ -50,7 +50,7 @@ def __init__(self, config, device_id): | |
|
|
||
| if self.config.model_path == "": | ||
| trained_model = get_dict_rcdm_model( | ||
| self.config.type_model, self.config.use_head | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
| self.config.type_model, self.config.use_head, self.config.pretrained_models_dir | ||
| ) | ||
| else: | ||
| trained_model = torch.load(self.config.model_path, map_location="cpu") | ||
|
|
@@ -63,7 +63,6 @@ def preprocess_input_image(self, input_image, size=224): | |
| data_utils.CenterCropLongEdge(), | ||
| transforms.Resize((size, size)), | ||
| transforms.ToTensor(), | ||
| transforms.Normalize(self.config.norm_mean, self.config.norm_std), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you remove this? Is it causing an error? |
||
| ] | ||
| ) | ||
| tensor_image = transform_list(input_image) | ||
|
|
@@ -89,9 +88,8 @@ def __call__(self, img): | |
| if not self.config.use_ddim | ||
| else self.diffusion.ddim_sample_loop | ||
| ) | ||
|
|
||
| img = img.unsqueeze(0).repeat(1, 1, 1, 1) | ||
| img = self.preprocess_input_image(img).cuda(self.device_id) | ||
| img = img.repeat(1, 1, 1, 1) | ||
| model_kwargs = {} | ||
|
|
||
| with torch.no_grad(): | ||
|
|
@@ -104,5 +102,4 @@ def __call__(self, img): | |
| model_kwargs=model_kwargs, | ||
| ) | ||
|
|
||
| print("Sampling completed!") | ||
| return sample.squeeze(0) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,29 +1,31 @@ | ||
| import ml_collections | ||
|
|
||
|
|
||
| def get_config(): | ||
| config = ml_collections.ConfigDict() | ||
| config.image_size = 128 # The size of the images to generate. | ||
| config.class_cond = False # If true, use class conditional generation. | ||
| config.type_model = "simclr" # Type of model to use (e.g., simclr, dino). | ||
| config.use_head = False # If true, use the projector/head for SSL representation. | ||
| config.model_path = "" # Replace with the path to your model if you have one. | ||
| config.use_ddim = False # If true, use DDIM sampler. | ||
| config.no_shared = True # If false, enables squeeze and excitation. | ||
| config.clip_denoised = True # If true, clip denoised images. | ||
| config.attention_resolutions = "32,16,8" # Resolutions to use for attention layers. | ||
| config.diffusion_steps = 100 # Number of diffusion steps. | ||
| config.learn_sigma = True # If true, learn the noise level. | ||
| config.noise_schedule = "linear" # Type of noise schedule (e.g., linear). | ||
| config.num_channels = 256 # Number of channels in the model. | ||
| config.num_heads = 4 # Number of attention heads. | ||
| config.num_res_blocks = 2 # Number of residual blocks. | ||
| config.resblock_updown = True # If true, use up/down sampling in resblocks. | ||
| config.use_fp16 = False # If true, use 16-bit floating point precision. | ||
| config.use_scale_shift_norm = True # If true, use scale-shift normalization. | ||
| config.ssl_image_size = 224 # Size of the input images for the SSL model. | ||
| config.ssl_image_channels = ( | ||
| 3 # Number of channels of the input images for the SSL model. | ||
| ) | ||
|
|
||
| return config | ||
| import ml_collections | ||
|
|
||
|
|
||
| def get_config(): | ||
| config = ml_collections.ConfigDict() | ||
| config.image_size = 128 # The size of the images to generate. | ||
| config.class_cond = False # If true, use class conditional generation. | ||
| config.pretrained_models_dir = "/ssd003/projects/aieng/genssl" # Path to the directory containing the model. | ||
| config.type_model = "simclr" # Type of model to use (e.g., simclr, dino). | ||
| config.use_head = False # If true, use the projector/head for SSL representation. | ||
| config.model_path = "" # Replace with the path to your model if you have one. | ||
| config.use_ddim = True # If true, use DDIM sampler. | ||
| config.no_shared = True # If false, enables squeeze and excitation. | ||
| config.clip_denoised = True # If true, clip denoised images. | ||
| config.attention_resolutions = "32,16,8" # Resolutions to use for attention layers. | ||
| config.diffusion_steps = 100 # Number of diffusion steps. | ||
| config.learn_sigma = True # If true, learn the noise level. | ||
| config.noise_schedule = "linear" # Type of noise schedule (e.g., linear). | ||
| config.num_channels = 256 # Number of channels in the model. | ||
| config.num_heads = 4 # Number of attention heads. | ||
| config.num_res_blocks = 2 # Number of residual blocks. | ||
| config.resblock_updown = True # If true, use up/down sampling in resblocks. | ||
| config.use_fp16 = False # If true, use 16-bit floating point precision. | ||
| config.use_scale_shift_norm = True # If true, use scale-shift normalization. | ||
| config.ssl_image_size = 224 # Size of the input images for the SSL model. | ||
| config.ssl_image_channels = ( | ||
| 3 # Number of channels of the input images for the SSL model. | ||
| ) | ||
| config.timestep_respacing = "ddim2" # Type of timestep respacing (e.g., ddim25). | ||
|
|
||
| return config | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update config to available ddim. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from torchvision import datasets, transforms | ||
| from torchvision.transforms import transforms | ||
|
Comment on lines
+1
to
+2
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These lines both import |
||
|
|
||
| from SimCLR.exceptions.exceptions import InvalidDatasetSelection | ||
| from SimCLR.datasets.data_aug.center_crop import CostumeCenterCrop | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we change this? (see the comment on the module) |
||
|
|
||
| class SupervisedDataset: | ||
| def __init__(self, root_folder): | ||
| self.root_folder = root_folder | ||
|
|
||
| @staticmethod | ||
| def get_transform(size): | ||
| """Return a set of simple transformations for supervised learning. | ||
|
|
||
| Args: | ||
| size (int): Image size. | ||
| """ | ||
| transform_list = [ | ||
| CostumeCenterCrop(), | ||
| transforms.Resize((size, size)), | ||
| transforms.ToTensor(), | ||
| ] | ||
|
|
||
| return transforms.Compose(transform_list) | ||
|
|
||
|
|
||
| def get_dataset(self, name, train = True): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: it should be |
||
| if name == "imagenet": | ||
| if train: | ||
| split = "train" | ||
| else: | ||
| split = "val" | ||
| return datasets.ImageNet( | ||
| self.root_folder, | ||
| split=split, | ||
| transform=self.get_transform(224), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we read these constants from a config? |
||
| ) | ||
| elif name == "cifar10": | ||
| return datasets.CIFAR10( | ||
| self.root_folder, | ||
| train=train, | ||
| transform= self.get_transform(32), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. |
||
| download=True, | ||
| ) | ||
| elif name == "stl10": | ||
| if train: | ||
| split = "train" | ||
| else: | ||
| split = "test" | ||
| return datasets.STL10( | ||
| self.root_folder, | ||
| split=split, | ||
| transform=self.get_transform(96), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. |
||
| download=True, | ||
| ) | ||
| else: | ||
| raise InvalidDatasetSelection() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import torch | ||
| from torch import nn | ||
| from torchvision import models | ||
|
|
||
| from ..exceptions.exceptions import InvalidBackboneError | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use absolute paths. |
||
|
|
||
|
|
||
| class PretrainedResNet(nn.Module): | ||
| def __init__(self, base_model, pretrained_model_file, linear_eval=True, num_classes=10): | ||
| super(PretrainedResNet, self).__init__() | ||
|
|
||
| self.pretrained_model_file = pretrained_model_file | ||
|
|
||
| self.resnet_dict = { | ||
| "resnet18": models.resnet18(pretrained=False, num_classes=num_classes), | ||
| "resnet50": models.resnet50(pretrained=False, num_classes=num_classes), | ||
| } | ||
|
|
||
| self.backbone = self._get_basemodel(base_model) | ||
|
|
||
| # load pretrained weights | ||
| log = self._load_pretrained() | ||
|
|
||
| assert log.missing_keys == ["fc.weight", "fc.bias"] | ||
|
|
||
| if linear_eval: | ||
| # freeze all layers but the last fc | ||
| self._freeze_backbone() | ||
| parameters = list(filter(lambda p: p.requires_grad, self.backbone.parameters())) | ||
| assert len(parameters) == 2 # fc.weight, fc.bias | ||
|
|
||
| def _load_pretrained(self): | ||
| checkpoint = torch.load(self.pretrained_model_file, map_location='cpu') | ||
| state_dict = checkpoint["state_dict"] | ||
| for k in list(state_dict.keys()): | ||
| if k.startswith("module.backbone."): | ||
| if not k.startswith("module.backbone.fc"): | ||
| # remove prefix | ||
| state_dict[k[len("module.backbone.") :]] = state_dict[k] | ||
| del state_dict[k] | ||
| log = self.backbone.load_state_dict(state_dict, strict=False) | ||
| return log | ||
|
|
||
|
|
||
| def _freeze_backbone(self): | ||
| # freeze all layers but the last fc | ||
| for name, param in self.backbone.named_parameters(): | ||
| if name not in ["fc.weight", "fc.bias"]: | ||
| param.requires_grad = False | ||
| return | ||
|
|
||
|
|
||
| def _get_basemodel(self, model_name): | ||
| try: | ||
| model = self.resnet_dict[model_name] | ||
| except KeyError: | ||
| raise InvalidBackboneError( | ||
| "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50", | ||
| ) | ||
| else: | ||
| return model | ||
|
|
||
| def forward(self, x): | ||
| return self.backbone(x) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import os | ||
| from datetime import datetime | ||
|
|
||
| import torch | ||
| from torch.cuda.amp import GradScaler, autocast | ||
|
|
@@ -18,7 +19,15 @@ def __init__(self, *args, **kwargs): | |
| self.optimizer = kwargs["optimizer"] | ||
| self.scheduler = kwargs["scheduler"] | ||
| self.device_id = kwargs["device_id"] | ||
| self.writer = SummaryWriter() | ||
| # Create a directory to save the model checkpoints and logs | ||
| now = datetime.now() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added logic to run_simCLR.py that does something like this. So you have to either remove these parts, or remove my changes before merging. |
||
| dt_string = now.strftime("%Y_%m_%d_%H_%M") | ||
| log_dir = os.path.join(args.model_dir, args.experiment_name,dt_string) | ||
| try: | ||
| os.makedirs(log_dir) | ||
| except FileExistsError: | ||
| print(f"Directory {log_dir} made by another worker", flush=True) | ||
| self.writer = SummaryWriter(log_dir) | ||
sanaAyrml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.criterion = loss.SimCLRContrastiveLoss(self.args.temperature).cuda( | ||
| self.device_id | ||
| ) | ||
|
|
@@ -62,7 +71,6 @@ def train(self, train_loader): | |
| self.scheduler.get_last_lr()[0], | ||
| global_step=n_iter, | ||
| ) | ||
|
|
||
| n_iter += 1 | ||
|
|
||
| # warmup for the first 10 epochs | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| #!/bin/bash | ||
|
|
||
| #SBATCH --job-name=train_sunrgbd | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rename |
||
| #SBATCH --partition=t4v2 | ||
| #SBATCH --nodes=1 | ||
| #SBATCH --gres=gpu:4 | ||
| #SBATCH --ntasks-per-node=4 | ||
| #SBATCH --cpus-per-task=4 | ||
| #SBATCH --mem=100G | ||
| #SBATCH --output=logs/simclr/eval_slurm-%N-%j.out | ||
| #SBATCH --error=logs/simclr/eval_slurm-%N-%j.err | ||
| #SBATCH --qos=m | ||
|
|
||
| PY_ARGS=${@:1} | ||
|
|
||
| # load virtual environment | ||
| source /ssd003/projects/aieng/envs/genssl2/bin/activate | ||
|
|
||
| export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend | ||
| export CUDA_LAUNCH_BLOCKING=1 | ||
|
|
||
| export MASTER_ADDR=$(hostname) | ||
| export MASTER_PORT=45679 | ||
|
|
||
| export PYTHONPATH="." | ||
| nvidia-smi | ||
|
|
||
| pretrained_model_dir="/projects/imagenet_synthetic/train_models" | ||
| experiment_name="simclr/2024_02_23_13_02" | ||
|
|
||
| cd $pretrained_model_dir/$experiment_name | ||
|
|
||
| files=$(ls checkpoint_epoch_*) | ||
|
|
||
| cd "$OLDPWD" | ||
|
|
||
| # Loop through each file and pass it as a parameter to the rest of the script | ||
| for file in $files | ||
| do | ||
| echo "Evaluating: $file" | ||
|
|
||
| # srun execute ntasks-per-node * nodes times | ||
| srun python evaluate_simCLR.py \ | ||
| --distributed_mode \ | ||
| --batch-size=256 \ | ||
| --pretrained_model_dir=$pretrained_model_dir \ | ||
| --experiment_name=$experiment_name \ | ||
| --pretrained_model_name=$file \ | ||
| --linear_evaluation | ||
| # Add your processing logic here | ||
| done | ||
Uh oh!
There was an error while loading. Please reload this page.