diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7d459c0..45f0c42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,15 +40,6 @@ repos: - id: nbqa-ruff args: [--fix] - - repo: local - hooks: - - id: pytest - name: pytest - entry: python3 -m pytest - language: system - pass_filenames: false - always_run: true - exclude: | (?x)( ^rcdm/| diff --git a/SimCLR/data_aug/contrastive_learning_dataset.py b/SimCLR/datasets/contrastive_learning_dataset.py similarity index 91% rename from SimCLR/data_aug/contrastive_learning_dataset.py rename to SimCLR/datasets/contrastive_learning_dataset.py index 74b3625..3723652 100644 --- a/SimCLR/data_aug/contrastive_learning_dataset.py +++ b/SimCLR/datasets/contrastive_learning_dataset.py @@ -1,11 +1,11 @@ from torchvision import datasets, transforms -from SimCLR.data_aug.gaussian_blur import GaussianBlur -from SimCLR.data_aug.icgan_aug import ICGANInference -from SimCLR.data_aug.icgan_config import get_icgan_config -from SimCLR.data_aug.rcdm_aug import RCDMInference -from SimCLR.data_aug.rcdm_config import get_config -from SimCLR.data_aug.view_generator import ContrastiveLearningViewGenerator +from SimCLR.datasets.data_aug.gaussian_blur import GaussianBlur +from SimCLR.datasets.data_aug.icgan_aug import ICGANInference +from SimCLR.datasets.data_aug.icgan_config import get_icgan_config +from SimCLR.datasets.data_aug.rcdm_aug import RCDMInference +from SimCLR.datasets.data_aug.rcdm_config import get_config +from SimCLR.datasets.view_generator import ContrastiveLearningViewGenerator from SimCLR.exceptions.exceptions import InvalidDatasetSelection @@ -114,3 +114,4 @@ def get_dataset( raise InvalidDatasetSelection() else: return dataset_fn() + \ No newline at end of file diff --git a/SimCLR/datasets/data_aug/center_crop.py b/SimCLR/datasets/data_aug/center_crop.py new file mode 100644 index 0000000..783749c --- /dev/null +++ b/SimCLR/datasets/data_aug/center_crop.py @@ -0,0 +1,44 @@ +import 
torch.nn.functional as F +import torchvision +import torch + +class CostumeCenterCrop(torch.nn.Module): + def __init__(self, size=None, ratio="1:1"): + super().__init__() + self.size = size + self.ratio = ratio + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + if self.size is None: + if isinstance(img, torch.Tensor): + h, w = img.shape[-2:] + else: + w, h = img.size + ratio = self.ratio.split(":") + ratio = float(ratio[0]) / float(ratio[1]) + # Size must match the ratio while cropping to the edge of the image + ratioed_w = int(h * ratio) + ratioed_h = int(w / ratio) + if w>=h: + if ratioed_h <= h: + size = (ratioed_h, w) + else: + size = (h, ratioed_w) + else: + if ratioed_w <= w: + size = (h, ratioed_w) + else: + size = (ratioed_h, w) + else: + size = self.size + return torchvision.transforms.functional.center_crop(img, size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + \ No newline at end of file diff --git a/SimCLR/data_aug/gaussian_blur.py b/SimCLR/datasets/data_aug/gaussian_blur.py similarity index 100% rename from SimCLR/data_aug/gaussian_blur.py rename to SimCLR/datasets/data_aug/gaussian_blur.py diff --git a/SimCLR/data_aug/icgan_aug.py b/SimCLR/datasets/data_aug/icgan_aug.py similarity index 100% rename from SimCLR/data_aug/icgan_aug.py rename to SimCLR/datasets/data_aug/icgan_aug.py diff --git a/SimCLR/data_aug/icgan_config.py b/SimCLR/datasets/data_aug/icgan_config.py similarity index 100% rename from SimCLR/data_aug/icgan_config.py rename to SimCLR/datasets/data_aug/icgan_config.py diff --git a/SimCLR/data_aug/rcdm_aug.py b/SimCLR/datasets/data_aug/rcdm_aug.py similarity index 90% rename from SimCLR/data_aug/rcdm_aug.py rename to SimCLR/datasets/data_aug/rcdm_aug.py index 7269022..4cb0ed2 100644 --- a/SimCLR/data_aug/rcdm_aug.py +++ b/SimCLR/datasets/data_aug/rcdm_aug.py @@ -21,7 +21,7 @@ def __init__(self, 
config, device_id): # Load SSL model self.ssl_model = ( - get_model(self.config.type_model, self.config.use_head) + get_model(self.config.type_model, self.config.use_head, self.config.pretrained_models_dir) .cuda(self.device_id) .eval() ) @@ -50,7 +50,7 @@ def __init__(self, config, device_id): if self.config.model_path == "": trained_model = get_dict_rcdm_model( - self.config.type_model, self.config.use_head + self.config.type_model, self.config.use_head, self.config.pretrained_models_dir ) else: trained_model = torch.load(self.config.model_path, map_location="cpu") @@ -63,7 +63,6 @@ def preprocess_input_image(self, input_image, size=224): data_utils.CenterCropLongEdge(), transforms.Resize((size, size)), transforms.ToTensor(), - transforms.Normalize(self.config.norm_mean, self.config.norm_std), ] ) tensor_image = transform_list(input_image) @@ -89,9 +88,8 @@ def __call__(self, img): if not self.config.use_ddim else self.diffusion.ddim_sample_loop ) - - img = img.unsqueeze(0).repeat(1, 1, 1, 1) img = self.preprocess_input_image(img).cuda(self.device_id) + img = img.repeat(1, 1, 1, 1) model_kwargs = {} with torch.no_grad(): @@ -104,5 +102,4 @@ def __call__(self, img): model_kwargs=model_kwargs, ) - print("Sampling completed!") return sample.squeeze(0) diff --git a/SimCLR/data_aug/rcdm_config.py b/SimCLR/datasets/data_aug/rcdm_config.py similarity index 85% rename from SimCLR/data_aug/rcdm_config.py rename to SimCLR/datasets/data_aug/rcdm_config.py index 4c646d9..9d08647 100644 --- a/SimCLR/data_aug/rcdm_config.py +++ b/SimCLR/datasets/data_aug/rcdm_config.py @@ -1,29 +1,31 @@ -import ml_collections - - -def get_config(): - config = ml_collections.ConfigDict() - config.image_size = 128 # The size of the images to generate. - config.class_cond = False # If true, use class conditional generation. - config.type_model = "simclr" # Type of model to use (e.g., simclr, dino). - config.use_head = False # If true, use the projector/head for SSL representation. 
- config.model_path = "" # Replace with the path to your model if you have one. - config.use_ddim = False # If true, use DDIM sampler. - config.no_shared = True # If false, enables squeeze and excitation. - config.clip_denoised = True # If true, clip denoised images. - config.attention_resolutions = "32,16,8" # Resolutions to use for attention layers. - config.diffusion_steps = 100 # Number of diffusion steps. - config.learn_sigma = True # If true, learn the noise level. - config.noise_schedule = "linear" # Type of noise schedule (e.g., linear). - config.num_channels = 256 # Number of channels in the model. - config.num_heads = 4 # Number of attention heads. - config.num_res_blocks = 2 # Number of residual blocks. - config.resblock_updown = True # If true, use up/down sampling in resblocks. - config.use_fp16 = False # If true, use 16-bit floating point precision. - config.use_scale_shift_norm = True # If true, use scale-shift normalization. - config.ssl_image_size = 224 # Size of the input images for the SSL model. - config.ssl_image_channels = ( - 3 # Number of channels of the input images for the SSL model. - ) - - return config +import ml_collections + + +def get_config(): + config = ml_collections.ConfigDict() + config.image_size = 128 # The size of the images to generate. + config.class_cond = False # If true, use class conditional generation. + config.pretrained_models_dir = "/ssd003/projects/aieng/genssl" # Path to the directory containing the model. + config.type_model = "simclr" # Type of model to use (e.g., simclr, dino). + config.use_head = False # If true, use the projector/head for SSL representation. + config.model_path = "" # Replace with the path to your model if you have one. + config.use_ddim = True # If true, use DDIM sampler. + config.no_shared = True # If false, enables squeeze and excitation. + config.clip_denoised = True # If true, clip denoised images. + config.attention_resolutions = "32,16,8" # Resolutions to use for attention layers. 
+ config.diffusion_steps = 100 # Number of diffusion steps. + config.learn_sigma = True # If true, learn the noise level. + config.noise_schedule = "linear" # Type of noise schedule (e.g., linear). + config.num_channels = 256 # Number of channels in the model. + config.num_heads = 4 # Number of attention heads. + config.num_res_blocks = 2 # Number of residual blocks. + config.resblock_updown = True # If true, use up/down sampling in resblocks. + config.use_fp16 = False # If true, use 16-bit floating point precision. + config.use_scale_shift_norm = True # If true, use scale-shift normalization. + config.ssl_image_size = 224 # Size of the input images for the SSL model. + config.ssl_image_channels = ( + 3 # Number of channels of the input images for the SSL model. + ) + config.timestep_respacing = "ddim2" # Type of timestep respacing (e.g., ddim25). + + return config diff --git a/SimCLR/datasets/supervised_dataset.py b/SimCLR/datasets/supervised_dataset.py new file mode 100644 index 0000000..51efb69 --- /dev/null +++ b/SimCLR/datasets/supervised_dataset.py @@ -0,0 +1,57 @@ +from torchvision import datasets, transforms +from torchvision.transforms import transforms + +from SimCLR.exceptions.exceptions import InvalidDatasetSelection +from SimCLR.datasets.data_aug.center_crop import CostumeCenterCrop + +class SupervisedDataset: + def __init__(self, root_folder): + self.root_folder = root_folder + + @staticmethod + def get_transform(size): + """Return a set of simple transformations for supervised learning. + + Args: + size (int): Image size. 
+ """ + transform_list = [ + CostumeCenterCrop(), + transforms.Resize((size, size)), + transforms.ToTensor(), + ] + + return transforms.Compose(transform_list) + + + def get_dataset(self, name, train = True): + if name == "imagenet": + if train: + split = "train" + else: + split = "val" + return datasets.ImageNet( + self.root_folder, + split=split, + transform=self.get_transform(224), + ) + elif name == "cifar10": + return datasets.CIFAR10( + self.root_folder, + train=train, + transform= self.get_transform(32), + download=True, + ) + elif name == "stl10": + if train: + split = "train" + else: + split = "test" + return datasets.STL10( + self.root_folder, + split=split, + transform=self.get_transform(96), + download=True, + ) + else: + raise InvalidDatasetSelection() \ No newline at end of file diff --git a/SimCLR/data_aug/view_generator.py b/SimCLR/datasets/view_generator.py similarity index 100% rename from SimCLR/data_aug/view_generator.py rename to SimCLR/datasets/view_generator.py diff --git a/SimCLR/models/resnet_pretrained.py b/SimCLR/models/resnet_pretrained.py new file mode 100644 index 0000000..bc37e85 --- /dev/null +++ b/SimCLR/models/resnet_pretrained.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torchvision import models + +from ..exceptions.exceptions import InvalidBackboneError + + +class PretrainedResNet(nn.Module): + def __init__(self, base_model, pretrained_model_file, linear_eval=True, num_classes=10): + super(PretrainedResNet, self).__init__() + + self.pretrained_model_file = pretrained_model_file + + self.resnet_dict = { + "resnet18": models.resnet18(pretrained=False, num_classes=num_classes), + "resnet50": models.resnet50(pretrained=False, num_classes=num_classes), + } + + self.backbone = self._get_basemodel(base_model) + + # load pretrained weights + log = self._load_pretrained() + + assert log.missing_keys == ["fc.weight", "fc.bias"] + + if linear_eval: + # freeze all layers but the last fc + self._freeze_backbone() + 
parameters = list(filter(lambda p: p.requires_grad, self.backbone.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + def _load_pretrained(self): + checkpoint = torch.load(self.pretrained_model_file, map_location='cpu') + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("module.backbone."): + if not k.startswith("module.backbone.fc"): + # remove prefix + state_dict[k[len("module.backbone.") :]] = state_dict[k] + del state_dict[k] + log = self.backbone.load_state_dict(state_dict, strict=False) + return log + + + def _freeze_backbone(self): + # freeze all layers but the last fc + for name, param in self.backbone.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + return + + + def _get_basemodel(self, model_name): + try: + model = self.resnet_dict[model_name] + except KeyError: + raise InvalidBackboneError( + "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50", + ) + else: + return model + + def forward(self, x): + return self.backbone(x) diff --git a/SimCLR/simclr.py b/SimCLR/simclr.py index 28e34f3..e03ee60 100644 --- a/SimCLR/simclr.py +++ b/SimCLR/simclr.py @@ -1,4 +1,5 @@ import os +from datetime import datetime import torch from torch.cuda.amp import GradScaler, autocast @@ -18,7 +19,15 @@ def __init__(self, *args, **kwargs): self.optimizer = kwargs["optimizer"] self.scheduler = kwargs["scheduler"] self.device_id = kwargs["device_id"] - self.writer = SummaryWriter() + # Create a directory to save the model checkpoints and logs + now = datetime.now() + dt_string = now.strftime("%Y_%m_%d_%H_%M") + log_dir = os.path.join(self.args.model_dir, self.args.experiment_name, dt_string) + try: + os.makedirs(log_dir) + except FileExistsError: + print(f"Directory {log_dir} made by another worker", flush=True) + self.writer = SummaryWriter(log_dir) self.criterion = loss.SimCLRContrastiveLoss(self.args.temperature).cuda( self.device_id ) @@ 
-62,7 +71,6 @@ def train(self, train_loader): self.scheduler.get_last_lr()[0], global_step=n_iter, ) - n_iter += 1 # warmup for the first 10 epochs diff --git a/eval_simclr.slrm b/eval_simclr.slrm new file mode 100644 index 0000000..6c4b627 --- /dev/null +++ b/eval_simclr.slrm @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name=train_sunrgbd +#SBATCH --partition=t4v2 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=100G +#SBATCH --output=logs/simclr/eval_slurm-%N-%j.out +#SBATCH --error=logs/simclr/eval_slurm-%N-%j.err +#SBATCH --qos=m + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +pretrained_model_dir="/projects/imagenet_synthetic/train_models" +experiment_name="simclr/2024_02_23_13_02" + +cd $pretrained_model_dir/$experiment_name + +files=$(ls checkpoint_epoch_*) + +cd "$OLDPWD" + +# Loop through each file and pass it as a parameter to the rest of the script +for file in $files +do + echo "Evaluating: $file" + + # srun execute ntasks-per-node * nodes times + srun python evaluate_simCLR.py \ + --distributed_mode \ + --batch-size=256 \ + --pretrained_model_dir=$pretrained_model_dir \ + --experiment_name=$experiment_name \ + --pretrained_model_name=$file \ + --linear_evaluation + # Add your processing logic here +done \ No newline at end of file diff --git a/evaluate_simCLR.py b/evaluate_simCLR.py new file mode 100644 index 0000000..cc1cca3 --- /dev/null +++ b/evaluate_simCLR.py @@ -0,0 +1,313 @@ +import argparse +import random +from functools import partial + +import os +import torch +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torch.utils.data.distributed import DistributedSampler +from torchvision import 
models +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from SimCLR.datasets.supervised_dataset import SupervisedDataset +from SimCLR.models.resnet_pretrained import PretrainedResNet + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch SimCLR") +parser.add_argument( + "-data", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset, for imagenet: /scratch/ssd004/datasets/imagenet256 ", +) +parser.add_argument( + "-dataset-name", + default="imagenet", + help="dataset-name", + choices=["stl10", "cifar10", "imagenet"], +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet18", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet18)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers", +) +parser.add_argument( + "--epochs", + default=100, + type=int, + metavar="N", + help="number of total epochs to run", +) +parser.add_argument( + "-b", + "--batch-size", + default=64, + type=int, + metavar="N", + help="mini-batch size (default: 256), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.0003, + type=float, + metavar="LR", + help="initial learning rate", + dest="lr", +) +parser.add_argument( + "--wd", + "--weight-decay", + default=8e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", +) +parser.add_argument( + "--seed", + default=42, + type=int, + help="seed for initializing training. 
", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--pretrained_model_dir", + default=None, + help="Path to the pretrained model directory.") +parser.add_argument( + "--pretrained_model_name", + default=None, + help="Name of pretrained model.") +parser.add_argument( + "--experiment_name", + default=None, + help="Name of the experiment.") +parser.add_argument( + "--linear_evaluation", + action="store_true", + help="Whether or not to evaluate the linear evaluation of the model.") +parser.add_argument( + "--enable_checkpointing", + action="store_true", + help="Whether or not to enable checkpointing of the model.") + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. 
+ """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +def save_checkpoint(state, filename="checkpoint.pth.tar"): + torch.save(state, filename) + + +def main(): + args = parser.parse_args() + print(args) + + torch.multiprocessing.set_start_method("spawn") + + if args.distributed_mode: + dist_utils.init_distributed_mode( + launcher=args.distributed_launcher, + backend=args.distributed_backend, + ) + device_id = torch.cuda.current_device() + else: + device_id = None + + dataset = SupervisedDataset(args.data) + train_dataset = dataset.get_dataset( + name = args.dataset_name, + train=True, + ) + test_dataset = dataset.get_dataset( + name = args.dataset_name, + train=False, + ) + train_sampler = None + test_sampler = None + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = DistributedSampler( + train_dataset, + seed=args.seed, + drop_last=True, + ) + test_sampler = DistributedSampler( + test_dataset, + seed=args.seed, + drop_last=False, + ) + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=False, + drop_last=True, + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, + 
batch_size=args.batch_size, + shuffle=(test_sampler is None), + sampler=test_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=False, + drop_last=False, + ) + if args.dataset_name == "cifar10": + num_classes = 10 + elif args.dataset_name == "stl10": + num_classes = 10 + elif args.dataset_name == "imagenet": + num_classes = 1000 + + model = PretrainedResNet( + base_model=args.arch, + pretrained_model_file = os.path.join(args.pretrained_model_dir, args.experiment_name, args.pretrained_model_name), + linear_eval=args.linear_evaluation, + num_classes=num_classes) + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + # set the single device scope, otherwise DistributedDataParallel will + # use all available devices + torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + model = model.cuda() + + optimizer = torch.optim.Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + ) + + criterion = torch.nn.CrossEntropyLoss().cuda(device_id) + + log_dir = os.path.join(args.pretrained_model_dir, args.experiment_name) + + for epoch_counter in tqdm(range(args.epochs), desc="Epoch Progress"): + if dist_utils.is_dist_avail_and_initialized(): + train_loader.sampler.set_epoch(epoch_counter) + top1_train_accuracy = 0 + counter = 0 + for x_batch, y_batch in tqdm(train_loader, desc="Training Progress"): + x_batch = x_batch.cuda(device_id) + y_batch = y_batch.cuda(device_id) + + logits = model(x_batch) + loss = criterion(logits, y_batch) + top1 = accuracy(logits, y_batch, topk=(1,)) + top1_train_accuracy += top1[0] + + optimizer.zero_grad() + loss.backward() + optimizer.step() + counter += 1 + + top1_train_accuracy /= counter + top1_accuracy = 0 + top5_accuracy = 0 + counter = 0 + for x_batch, y_batch in tqdm(test_loader, desc="Evaluation Progress"): + x_batch = x_batch.cuda(device_id) + y_batch = y_batch.cuda(device_id) + + logits = model(x_batch) 
+ + top1, top5 = accuracy(logits, y_batch, topk=(1, 5)) + top1_accuracy += top1[0] + top5_accuracy += top5[0] + counter += 1 + + top1_accuracy /= counter + top5_accuracy /= counter + print( + f"Epoch {epoch_counter}\t Top1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}", + flush=True, + ) + if args.enable_checkpointing: + checkpoint_name = "checkpoint_supervised_epoch_{:04d}.pth.tar".format(epoch_counter) + save_checkpoint( + { + "n_epoch": epoch_counter, + "arch": args.arch, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + }, + filename=os.path.join(log_dir, checkpoint_name), + ) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rcdm/guided_diffusion_rcdm/get_rcdm_models.py b/rcdm/guided_diffusion_rcdm/get_rcdm_models.py index 545e6e3..803c431 100644 --- a/rcdm/guided_diffusion_rcdm/get_rcdm_models.py +++ b/rcdm/guided_diffusion_rcdm/get_rcdm_models.py @@ -4,69 +4,68 @@ import torch.nn as nn from torchvision import models as torchvision_models - -def get_dict_rcdm_model(model="dino", use_head=False): +def get_dict_rcdm_model(model="dino", use_head=False, model_dir='./'): """ Download checkpoints of RCDM. 
""" if model == "supervised": trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_supervised.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_supervised.pt", + map_location="cpu", + model_dir=model_dir) return trained_model elif model == "simclr": if use_head: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_simclr_head.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_simclr_head.pt", + map_location="cpu", + model_dir=model_dir) else: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_simclr_trunk.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_simclr_trunk.pt", + map_location="cpu", + model_dir=model_dir) return trained_model elif model == "barlow": if use_head: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_barlow_head.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_barlow_head.pt", + map_location="cpu", + model_dir=model_dir) else: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_barlow_trunk.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_barlow_trunk.pt", + map_location="cpu", + model_dir=model_dir) return trained_model elif model == "vicreg": if use_head: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_vicreg_head.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_vicreg_head.pt", + map_location="cpu", + model_dir=model_dir) else: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_vicreg_trunk.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_vicreg_trunk.pt", + map_location="cpu", + model_dir=model_dir) 
return trained_model elif model == "dino": if use_head: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_dino_head.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_dino_head.pt", + map_location="cpu", + model_dir=model_dir) else: trained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_dino_trunk.pt", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/rcdm/rcdm_ema_dino_trunk.pt", + map_location="cpu", + model_dir=model_dir) return trained_model else: diff --git a/rcdm/guided_diffusion_rcdm/get_ssl_models.py b/rcdm/guided_diffusion_rcdm/get_ssl_models.py index 810bc02..497b32e 100644 --- a/rcdm/guided_diffusion_rcdm/get_ssl_models.py +++ b/rcdm/guided_diffusion_rcdm/get_ssl_models.py @@ -106,8 +106,7 @@ def forward(self, x): x = torch.nn.functional.normalize(x, dim=-1, p=2).detach() return x - -def get_model(model="dino", use_head=False): +def get_model(model="dino", use_head=False, model_dir='./'): """ Select a model that will be used to compute the embeddings needed by RCDM. You can use any kind of model, ConvNets/MLPs, or VITs. 
@@ -128,9 +127,9 @@ def get_model(model="dino", use_head=False): use_bn=True, ) pretrained_model = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_full_checkpoint.pth", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_full_checkpoint.pth", + map_location="cpu", + model_dir=model_dir) pretrained_model = pretrained_model["teacher"] if "state_dict" in pretrained_model: pretrained_model = pretrained_model["state_dict"] @@ -160,9 +159,9 @@ def get_model(model="dino", use_head=False): embedding_model = torchvision_models.resnet50() embedding_model.fc = nn.Identity() pretrained_model_base = torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_1000ep_simclr_8node_resnet_16_07_20.afe428c7/model_final_checkpoint_phase999.torch", - map_location="cpu", - ) + "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_1000ep_simclr_8node_resnet_16_07_20.afe428c7/model_final_checkpoint_phase999.torch", + map_location="cpu", + model_dir=model_dir) # Load trunk pretrained_model = pretrained_model_base["classy_state_dict"]["base_model"][ "model" @@ -195,6 +194,7 @@ def get_model(model="dino", use_head=False): pretrained_model_base = torch.hub.load_state_dict_from_url( "https://dl.fbaipublicfiles.com/vissl/model_zoo/barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch", map_location="cpu", + model_dir=model_dir, ) # Load trunk pretrained_model = pretrained_model_base["classy_state_dict"]["base_model"][ @@ -232,6 +232,7 @@ def get_model(model="dino", use_head=False): pretrained_model_base = torch.hub.load_state_dict_from_url( "https://dl.fbaipublicfiles.com/vicreg/resnet50_fullckpt.pth", map_location="cpu", + model_dir=model_dir, ) embedding_model.classifier = nn.Identity() embedding_model.projector = Projector(emb=8192) diff --git a/run_simCLR.py b/run_simCLR.py index 59a6a85..c307027 
100644 --- a/run_simCLR.py +++ b/run_simCLR.py @@ -9,7 +9,7 @@ from torchvision import models from SimCLR import distributed as dist_utils -from SimCLR.data_aug.contrastive_learning_dataset import ContrastiveLearningDataset +from SimCLR.datasets.contrastive_learning_dataset import ContrastiveLearningDataset from SimCLR.models.resnet_simclr import ResNetSimCLR from SimCLR.simclr import SimCLR from torch.utils.data import Subset @@ -139,6 +139,8 @@ ) parser.add_argument("--distributed_launcher", default="slurm") parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument("--model_dir", default="model_checkpoints") +parser.add_argument("--experiment_name", default="simclr") parser.add_argument( "--subset_fraction", default=1.0, @@ -171,6 +173,7 @@ def main(): args = parser.parse_args() print(args) + # Set the start method to spawn for distributed training torch.multiprocessing.set_start_method("spawn") assert ( @@ -248,7 +251,6 @@ def main(): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1 ) - simclr = SimCLR( model=model, optimizer=optimizer, @@ -260,4 +262,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py new file mode 100644 index 0000000..b0cfb31 --- /dev/null +++ b/tests/test_evaluation.py @@ -0,0 +1,34 @@ +import pytest +import torch +from evaluate_simCLR import accuracy + +def test_accuracy()-> None: + # Create sample data + output = torch.tensor([[0.1, 0.5, 0.3], [0.2, 0.6, 0.2]]) + target = torch.tensor([1, 2]) + topk = (1,) + + # Calculate accuracy + res = accuracy(output, target, topk=topk) + + # Check if the result matches the expected accuracy + expected_accuracy = [50.0] + assert res == expected_accuracy + +def test_accuracy_topk_5(): + # Create sample data + output = torch.tensor([[0.1, 0.5, 0.3, 0.1, 0.4, 0.5, 0.2, 0.3, 0.1, 0.9], + [0.2, 0.6, 0.2, 0.1, 0.3, 
0.6, 0.2, 0.4, 0.1, 0.8], + [0.3, 0.4, 0.3, 0.2, 0.5, 0.3, 0.1, 0.7, 0.2, 0.6], + [0.4, 0.3, 0.3, 0.5, 0.6, 0.1, 0.2, 0.8, 0.1, 0.7]]) + target = torch.tensor([6, 7, 8, 9]) # Only 2 of these 4 targets fall in the top-5 predictions + topk = (5,) + + # Calculate accuracy + res = accuracy(output, target, topk=topk) + print(res) + + # Check if the result matches the expected accuracy + # In this case, the expected accuracy is 50.0 (2 of 4 targets are in the top 5) + expected_accuracy = [50.0] + assert res == expected_accuracy \ No newline at end of file diff --git a/train_simclr.slrm b/train_simclr.slrm index 477a1fb..0a352eb 100644 --- a/train_simclr.slrm +++ b/train_simclr.slrm @@ -1,21 +1,22 @@ #!/bin/bash #SBATCH --job-name=train_sunrgbd -#SBATCH --partition=t4v2 -#SBATCH --time=12:00:00 +#SBATCH --partition=a100 #SBATCH --nodes=1 #SBATCH --gres=gpu:4 -#SBATCH --ntasks-per-node=4 +#SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=4 -#SBATCH --mem-per-cpu=2G -#SBATCH --output=slurm-%N-%j.out +#SBATCH --mem=100G +#SBATCH --output=logs/simclr/slurm-%N-%j.out +#SBATCH --error=logs/simclr/slurm-%N-%j.err +#SBATCH --qos=a100_arashaf PY_ARGS=${@:1} # load virtual environment source /ssd003/projects/aieng/envs/genssl2/bin/activate -export NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend export CUDA_LAUNCH_BLOCKING=1 export MASTER_ADDR=$(hostname) @@ -24,9 +25,10 @@ export MASTER_PORT=45679 export PYTHONPATH="." nvidia-smi -# “srun” executes the script times -srun python run_simCLR.py \ +# srun executes the command ntasks-per-node * nodes times +srun python run_simCLR.py \ --fp16-precision \ --distributed_mode \ ---batch-size=4 \ ---icgan_augmentation +--batch-size=256 \ +--model_dir="/projects/imagenet_synthetic/train_models" \ +--experiment_name="simclr" \ No newline at end of file