diff --git a/.gitignore b/.gitignore index a0b56cf..656e049 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # pycharm .idea/ + +# Trained models +trained_models/ diff --git a/SimCLR/data_aug/imagenet_synthetic_dataset.py b/SimCLR/data_aug/imagenet_synthetic_dataset.py index 66c245e..7c05a6c 100644 --- a/SimCLR/data_aug/imagenet_synthetic_dataset.py +++ b/SimCLR/data_aug/imagenet_synthetic_dataset.py @@ -3,6 +3,7 @@ import os import random +import torch from PIL import Image from torchvision import datasets, transforms @@ -37,6 +38,7 @@ def __init__( imagenet_synthetic_root, index_min=0, index_max=9, + generative_augmentation_prob=None, load_one_real_image=False, split="train", ): @@ -48,6 +50,7 @@ def __init__( self.imagenet_synthetic_root = imagenet_synthetic_root self.index_min = index_min self.index_max = index_max + self.generative_augmentation_prob = generative_augmentation_prob self.load_one_real_image = load_one_real_image self.synthetic_transforms = _get_simclr_transforms(size=224) self.real_transforms = _get_simclr_transforms(size=224, random_crop=True) @@ -62,21 +65,37 @@ def _synthetic_image(filename): filename_parent_dir = filename.split("/")[-2] image_path = os.path.join( self.imagenet_synthetic_root, - # self.split, + self.split, filename_parent_dir, filename_and_extension.split(".")[0] + f"_{rand_int}.JPEG", ) return Image.open(image_path).convert("RGB") - if self.load_one_real_image: - image1 = self.loader(os.path.join(self.root, imagenet_filename)) - image1 = self.real_transforms(image1) - else: - image1 = _synthetic_image(imagenet_filename) - image1 = self.synthetic_transforms(image1) + if self.generative_augmentation_prob is not None: + if torch.rand(1) < self.generative_augmentation_prob: + # Generate a synthetic image. + image1 = _synthetic_image(imagenet_filename) + image1 = self.synthetic_transforms(image1) + else: + image1 = self.loader(os.path.join(self.root, imagenet_filename)) + image1 = self.real_transforms(image1) - # image2 is always synthetic. - image2 = _synthetic_image(imagenet_filename) - image2 = self.synthetic_transforms(image2) + if torch.rand(1) < self.generative_augmentation_prob: + # Generate another synthetic image. + image2 = _synthetic_image(imagenet_filename) + image2 = self.synthetic_transforms(image2) + else: + image2 = self.loader(os.path.join(self.root, imagenet_filename)) + image2 = self.real_transforms(image2) + else: + if self.load_one_real_image: + image1 = self.loader(os.path.join(self.root, imagenet_filename)) + image1 = self.real_transforms(image1) + else: + image1 = _synthetic_image(imagenet_filename) + image1 = self.synthetic_transforms(image1) + # image2 is always synthetic. + image2 = _synthetic_image(imagenet_filename) + image2 = self.synthetic_transforms(image2) return {"view1": image1, "view2": image2}, label diff --git a/SimCLR/simclr.py b/SimCLR/simclr.py index 8178066..39dbc31 100644 --- a/SimCLR/simclr.py +++ b/SimCLR/simclr.py @@ -23,6 +23,17 @@ def __init__(self, *args, **kwargs): self.device_id, ) self.checkpoint_dir = self.args.checkpoint_dir + self.start_epoch = 0 + + if self.args.last_checkpoint: + checkpoint = torch.load(self.args.last_checkpoint) + self.model.load_state_dict(checkpoint["state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + # Start from the next epoch. + self.start_epoch = checkpoint["epoch"] + 1 + print( + f"Checkpoint loaded. Resuming training from epoch: {self.start_epoch}" + ) def train(self, train_loader): scaler = GradScaler(enabled=self.args.fp16_precision) @@ -32,9 +43,12 @@ def train(self, train_loader): print(f"Log dir: {self.writer.log_dir}") n_iter = 0 - print(f"Start SimCLR training for {self.args.epochs} epochs.") + print( + f"Start SimCLR training for {self.args.epochs} epochs starting from {self.start_epoch}." + ) - for epoch_counter in tqdm(range(self.args.epochs), desc="Training Progress"): + train_range = range(self.start_epoch, self.args.epochs) + for epoch_counter in tqdm(train_range, desc="Training Progress"): if dist_utils.is_dist_avail_and_initialized(): train_loader.sampler.set_epoch(epoch_counter) for images, _ in tqdm(train_loader): @@ -77,7 +91,7 @@ def train(self, train_loader): checkpoint_file = os.path.join(self.checkpoint_dir, checkpoint_name) save_checkpoint( { - "epoch": self.args.epochs, + "epoch": epoch_counter, "arch": self.args.arch, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), diff --git a/eval_scripts/food101/eval_food101_original_simsiam_100.slrm b/eval_scripts/food101/eval_food101_original_simsiam_100.slrm new file mode 100644 index 0000000..7f38bfc --- /dev/null +++ b/eval_scripts/food101/eval_food101_original_simsiam_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_original"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="food101" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/_original_simsiam/checkpoint_0099.pth.tar" diff --git a/eval_scripts/food101/eval_food101_simsiam_baseline_100.slrm b/eval_scripts/food101/eval_food101_simsiam_baseline_100.slrm new file mode 100644 index 0000000..1954f69 --- /dev/null +++ b/eval_scripts/food101/eval_food101_simsiam_baseline_100.slrm @@ -0,0 +1,40 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_baseline"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="food101" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0099.pth.tar" +# --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0090.pth.tar" diff --git a/eval_scripts/food101/eval_food101_simsiam_icgan_100.slrm b/eval_scripts/food101/eval_food101_simsiam_icgan_100.slrm new file mode 100644 index 0000000..4545b1c --- /dev/null +++ b/eval_scripts/food101/eval_food101_simsiam_icgan_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_icgan"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="food101" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_2024-02-29-18-40/checkpoint_0099.pth.tar" diff --git a/eval_scripts/food101/eval_food101_simsiam_stablediff_100.slrm b/eval_scripts/food101/eval_food101_simsiam_stablediff_100.slrm new file mode 100644 index 0000000..5cc53fc --- /dev/null +++ b/eval_scripts/food101/eval_food101_simsiam_stablediff_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/food101/evaluate_stable_diff"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="food101" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" diff --git a/eval_scripts/places365/eval_places365_original_simsiam_100.slrm b/eval_scripts/places365/eval_places365_original_simsiam_100.slrm new file mode 100644 index 0000000..a617e41 --- /dev/null +++ b/eval_scripts/places365/eval_places365_original_simsiam_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_original"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="places365" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/_original_simsiam/checkpoint_0099.pth.tar" diff --git a/eval_scripts/places365/eval_places365_simsiam_baseline_100.slrm b/eval_scripts/places365/eval_places365_simsiam_baseline_100.slrm new file mode 100644 index 0000000..0160209 --- /dev/null +++ b/eval_scripts/places365/eval_places365_simsiam_baseline_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_baseline"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="places365" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0099.pth.tar" \ No newline at end of file diff --git a/eval_scripts/places365/eval_places365_simsiam_icgan_100.slrm b/eval_scripts/places365/eval_places365_simsiam_icgan_100.slrm new file mode 100644 index 0000000..48975e1 --- /dev/null +++ b/eval_scripts/places365/eval_places365_simsiam_icgan_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_icgan"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="places365" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_2024-02-29-18-40/checkpoint_0099.pth.tar" diff --git a/eval_scripts/places365/eval_places365_simsiam_stablediff_100.slrm b/eval_scripts/places365/eval_places365_simsiam_stablediff_100.slrm new file mode 100644 index 0000000..cd8de93 --- /dev/null +++ b/eval_scripts/places365/eval_places365_simsiam_stablediff_100.slrm @@ -0,0 +1,39 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=t4v2 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --time=36:00:00 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval_downstream_datasets.py \ +--data_dir="/projects/imagenet_synthetic/fereshteh_datasets/places365/" \ +--checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/places365/evaluate_stable_diff"\ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--dataset_name="places365" \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" diff --git a/eval_simsiam_baseline.slrm b/eval_simsiam_baseline.slrm new file mode 100644 index 0000000..44a6299 --- /dev/null +++ b/eval_simsiam_baseline.slrm @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=a40 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --time=72:00:00 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval.py \ +--data_dir="/scratch/ssd004/datasets/imagenet256" \ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_2024-02-29-14-49/checkpoint_0099.pth.tar" \ + +# --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/eval_original_simsiam_baseline" diff --git a/eval_simsiam_icgan.slrm b/eval_simsiam_icgan.slrm new file mode 100644 index 0000000..6c1070a --- /dev/null +++ b/eval_simsiam_icgan.slrm @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=a40 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --time=72:00:00 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval.py \ +--data_dir="/scratch/ssd004/datasets/imagenet256" \ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_2024-02-29-18-40/checkpoint_0099.pth.tar" \ + +# --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/eval_simsiam_icgan_ep0099" \ No newline at end of file diff --git a/eval_simsiam_multinode.slrm b/eval_simsiam_multinode.slrm new file mode 100644 index 0000000..1696a27 --- /dev/null +++ b/eval_simsiam_multinode.slrm @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_multi_train" +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --nodes=2 +#SBATCH --gres=gpu:a40:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=multinode-%j.out +#SBATCH --error=multinode-%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=08:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 + + +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +export PYTHONPATH="." +nvidia-smi + +srun -p $SLURM_JOB_PARTITION \ + -c $SLURM_CPUS_ON_NODE \ + -N $SLURM_JOB_NUM_NODES \ + --mem=0 \ + --gres=gpu:$SLURM_JOB_PARTITION:$SLURM_GPUS_ON_NODE \ + bash -c 'torchrun \ + --nproc-per-node=$SLURM_GPUS_ON_NODE \ + --nnodes=$SLURM_JOB_NUM_NODES \ + --rdzv-endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv-id $RDVZ_ID \ + --rdzv-backend c10d \ + simsiam/adil_linear_eval_original_logs.py \ + --data_dir="/scratch/ssd004/datasets/imagenet256" \ + --arch="resnet50" \ + --distributed_mode \ + --batch-size=1024 \ + --lars \ + --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" / + ' \ No newline at end of file diff --git a/eval_simsiam_singlenode.slrm b/eval_simsiam_singlenode.slrm new file mode 100644 index 0000000..efeacb2 --- /dev/null +++ b/eval_simsiam_singlenode.slrm @@ -0,0 +1,33 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_single_train" +#SBATCH --partition=a40 +#SBATCH --qos=m +#SBATCH --nodes=1 +#SBATCH --gres=gpu:a40:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=singlenode-%j.out +#SBATCH --error=singlenode-%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=12:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 + +export PYTHONPATH="." +nvidia-smi + +torchrun --nproc-per-node=4 --nnodes=1 simsiam/adil_linear_eval_original_logs.py \ + --data_dir="/scratch/ssd004/datasets/imagenet256" \ + --arch="resnet50" \ + --distributed_mode \ + --batch-size=1024 \ + --lars \ + --pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" \ No newline at end of file diff --git a/eval_simsiam_stablediff.slrm b/eval_simsiam_stablediff.slrm new file mode 100644 index 0000000..9274bbe --- /dev/null +++ b/eval_simsiam_stablediff.slrm @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_eval" +#SBATCH --partition=a40 +#SBATCH --time=72:00:00 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/linear_eval.py \ +--data_dir="/scratch/ssd004/datasets/imagenet256" \ +--arch="resnet50" \ +--distributed_mode \ +--batch-size=1024 \ +--lars \ +--pretrained_checkpoint="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_2024-02-29-15-27/checkpoint_0099.pth.tar" \ + +# --checkpoint_dir="/projects/imagenet_synthetic/model_checkpoints/eval_simsiam_stablediff_ep0099" \ No newline at end of file diff --git a/original_eval_scripts/CIFAR10/baseline.slrm b/original_eval_scripts/CIFAR10/baseline.slrm new file mode 100644 index 0000000..edcdb76 --- /dev/null +++ b/original_eval_scripts/CIFAR10/baseline.slrm @@ -0,0 +1,47 @@ +#!/bin/bash + +#SBATCH --job-name="cifar" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=slurm-cifar10_baseline_160_%j.out +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_seed43_bs128_rforig_2024-03-05-12-27/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="cifar10" \ + --num_classes=10 diff --git a/original_eval_scripts/CIFAR10/icgan.slrm b/original_eval_scripts/CIFAR10/icgan.slrm new file mode 100644 index 0000000..b416f0f --- /dev/null +++ b/original_eval_scripts/CIFAR10/icgan.slrm @@ -0,0 +1,47 @@ +#!/bin/bash + +#SBATCH --job-name="cifar" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=cifar10_baseline_160_%j.out +#SBATCH --error=cifar10_baseline_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_seed43_bs128_rforig_2024-03-05-12-52/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="cifar10" \ No newline at end of file diff --git a/original_eval_scripts/CIFAR10/icgan_ab.slrm b/original_eval_scripts/CIFAR10/icgan_ab.slrm new file mode 100644 index 0000000..09558fa --- /dev/null +++ b/original_eval_scripts/CIFAR10/icgan_ab.slrm @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="cifar" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_cifar10_%j.out +#SBATCH --error=icgan_ab_cifar10_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="cifar10" \ + --num_classes=10 \ + --ablation_mode="icgan" diff --git a/original_eval_scripts/CIFAR100/baseline.slrm b/original_eval_scripts/CIFAR100/baseline.slrm new file mode 100644 index 0000000..1c9ed80 --- /dev/null +++ b/original_eval_scripts/CIFAR100/baseline.slrm @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="cifar" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=cifar100_baseline_160_%j.out +#SBATCH --error=cifar100_baseline_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_seed43_bs128_rforig_2024-03-05-12-27/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="cifar100" \ + --num_classes=100 \ No newline at end of file diff --git a/original_eval_scripts/CIFAR100/icgan_ab.slrm b/original_eval_scripts/CIFAR100/icgan_ab.slrm new file mode 100644 index 0000000..97b6373 --- /dev/null +++ b/original_eval_scripts/CIFAR100/icgan_ab.slrm @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="cifar" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_cifar100_%j.out +#SBATCH --error=icgan_ab_cifar100_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="cifar100" \ + --num_classes=100 \ + --ablation_mode="icgan" \ No newline at end of file diff --git a/original_eval_scripts/INaturalist/baseline.slrm b/original_eval_scripts/INaturalist/baseline.slrm new file mode 100644 index 0000000..84af317 --- /dev/null +++ b/original_eval_scripts/INaturalist/baseline.slrm @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="inaturalist" +#SBATCH --partition=t4v2 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=inaturalist_baseline_160_%j.out +#SBATCH --error=inaturalist_baseline_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/datasets/inat_comp/2018/" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_seed43_bs128_rforig_2024-03-05-12-27/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="INaturalist" \ + --num_classes=8142 \ No newline at end of file diff --git a/original_eval_scripts/INaturalist/icgan.slrm b/original_eval_scripts/INaturalist/icgan.slrm new file mode 100644 index 0000000..fbe6a3e --- /dev/null +++ b/original_eval_scripts/INaturalist/icgan.slrm @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="inaturalist" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=inaturalist_icgan_160_%j.out +#SBATCH --error=inaturalist_icgan_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/datasets/inat_comp/2018/" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_seed43_bs128_rforig_2024-03-05-12-52/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="INaturalist" \ + --num_classes=8142 \ No newline at end of file diff --git a/original_eval_scripts/INaturalist/icgan_ab.slrm b/original_eval_scripts/INaturalist/icgan_ab.slrm new file mode 100644 index 0000000..d579a3f --- /dev/null +++ b/original_eval_scripts/INaturalist/icgan_ab.slrm @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="inaturalist" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_inaturalist_%j.out +#SBATCH --error=icgan_ab_inaturalist_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/datasets/inat_comp/2018/" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="INaturalist" \ + --num_classes=8142 \ + --ablation_mode="icgan" \ No newline at end of file diff --git a/original_eval_scripts/INaturalist/stablediff.slrm b/original_eval_scripts/INaturalist/stablediff.slrm new file mode 100644 index 0000000..a3baaa1 --- /dev/null +++ b/original_eval_scripts/INaturalist/stablediff.slrm @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="inaturalist" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=inaturalist_stablediff_160_%j.out +#SBATCH --error=inaturalist_stablediff_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/datasets/inat_comp/2018/" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_p0p5_seed43_2024-03-05-13-39/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="INaturalist" \ + --num_classes=8142 \ No newline at end of file diff --git a/original_eval_scripts/food101/baseline.slrm b/original_eval_scripts/food101/baseline.slrm new file mode 100644 index 0000000..8e9fe85 --- /dev/null +++ b/original_eval_scripts/food101/baseline.slrm @@ -0,0 +1,47 @@ +#!/bin/bash + +#SBATCH --job-name="food101" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=slurm-food101_baseline_160_%j.out +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_seed43_bs128_rforig_2024-03-05-12-27/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="food101" \ + --num_classes=101 \ No newline at end of file diff --git a/original_eval_scripts/food101/icgan_ab.slrm b/original_eval_scripts/food101/icgan_ab.slrm new file mode 100644 index 0000000..30ce13e --- /dev/null +++ b/original_eval_scripts/food101/icgan_ab.slrm @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="food101" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_food101_%j.out +#SBATCH --error=icgan_ab_food101_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="food101" \ + --num_classes=101 \ + --ablation_mode="icgan" \ No newline at end of file diff --git a/original_eval_scripts/imagenet/baseline.slrm b/original_eval_scripts/imagenet/baseline.slrm new file mode 100644 index 0000000..a9b6e1e --- /dev/null +++ b/original_eval_scripts/imagenet/baseline.slrm @@ -0,0 +1,46 @@ +#!/bin/bash + +#SBATCH --job-name="sana_eval" +#SBATCH --partition=t4v2 + +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=singlenode_stablediff_160_%j.out +#SBATCH --error=singlenode_stablediff_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/scratch/ssd004/datasets/imagenet256" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars --batch-size=2048 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_p0p5_seed43_2024-03-05-13-39/checkpoint_0160.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" diff --git a/original_eval_scripts/imagenet/clip.slrm b/original_eval_scripts/imagenet/clip.slrm new file mode 100644 index 0000000..25f8638 --- /dev/null +++ b/original_eval_scripts/imagenet/clip.slrm @@ -0,0 +1,43 @@ +#!/bin/bash + +#SBATCH --job-name="clip_eval" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=clip_ab_imagenet_%j.out +#SBATCH --error=clip_ab_imagenet_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code_clip.py \ + --data="/scratch/ssd004/datasets/imagenet256" \ + --multiprocessing-distributed \ + --lars --batch-size=2048 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ \ No newline at end of file diff --git a/original_eval_scripts/imagenet/icgan_ab.slrm b/original_eval_scripts/imagenet/icgan_ab.slrm new file mode 100644 index 0000000..0020867 --- /dev/null +++ b/original_eval_scripts/imagenet/icgan_ab.slrm @@ -0,0 +1,46 @@ +#!/bin/bash + +#SBATCH --job-name="sana_eval" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_imagenet_%j.out +#SBATCH --error=icgan_ab_imagenet_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/scratch/ssd004/datasets/imagenet256" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars --batch-size=2048 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --ablation_mode="icgan" diff --git a/original_eval_scripts/places365/baseline.slrm b/original_eval_scripts/places365/baseline.slrm new file mode 100644 index 0000000..b5854a7 --- /dev/null +++ b/original_eval_scripts/places365/baseline.slrm @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH --job-name="places365" +#SBATCH --partition=rtx6000 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=places365_baseline_160_%j.out +#SBATCH --error=places365_baseline_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets/places365" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_baseline_seed43_bs128_rforig_2024-03-05-12-27/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="places365" \ + --num_classes=434 \ No newline at end of file diff --git a/original_eval_scripts/places365/icgan.slrm b/original_eval_scripts/places365/icgan.slrm new file mode 100644 index 0000000..841bbe9 --- /dev/null +++ b/original_eval_scripts/places365/icgan.slrm @@ -0,0 +1,47 @@ +#!/bin/bash + +#SBATCH --job-name="places365" +#SBATCH --partition=a40 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=places365_icgan_160_%j.out +#SBATCH --error=places365_icgan_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets/places365" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_icgan_seed43_bs128_rforig_2024-03-05-12-52/checkpoint_0160.pth.tar"\ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="places365" \ No newline at end of file diff --git a/original_eval_scripts/places365/icgan_ab.slrm b/original_eval_scripts/places365/icgan_ab.slrm new file mode 100644 index 0000000..486933d --- /dev/null +++ b/original_eval_scripts/places365/icgan_ab.slrm @@ -0,0 +1,49 @@ +#!/bin/bash + +#SBATCH --job-name="places365" +#SBATCH --partition=rtx6000 +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=icgan_ab_places365_%j.out +#SBATCH --error=icgan_ab_places365_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/projects/imagenet_synthetic/fereshteh_datasets/places365" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars \ + --batch-size=4096 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" \ + --dataset_name="places365" \ + --num_classes=434 \ + --ablation_mode="icgan" \ No newline at end of file diff --git a/original_eval_simsiam.slrm b/original_eval_simsiam.slrm new file mode 100644 index 0000000..a9b6e1e --- /dev/null +++ b/original_eval_simsiam.slrm @@ -0,0 +1,46 @@ +#!/bin/bash + +#SBATCH --job-name="sana_eval" +#SBATCH --partition=t4v2 + +#SBATCH --qos=deadline +#SBATCH --account=deadline +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=singlenode_stablediff_160_%j.out +#SBATCH --error=singlenode_stablediff_160_%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=72:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +echo $MASTER_ADDR +echo $MASTER_PORT + +export PYTHONPATH="." +nvidia-smi + +python simsiam/linear_eval_original_code.py \ + --data="/scratch/ssd004/datasets/imagenet256" \ + --arch="resnet50" \ + --multiprocessing-distributed \ + --lars --batch-size=2048 \ + --epochs=100 \ + -j=16 \ + --world-size 1 \ + --rank 0 \ + --pretrained="/projects/imagenet_synthetic/model_checkpoints/simsiam_stablediff_p0p5_seed43_2024-03-05-13-39/checkpoint_0160.pth.tar" \ + --dist-url "tcp://$MASTER_ADDR:$MASTER_PORT" diff --git a/run_simCLR.py b/run_simCLR.py index b7ee380..072b73b 100644 --- a/run_simCLR.py +++ b/run_simCLR.py @@ -177,6 +177,17 @@ type=int, help="Synthetic data files are named filename_i.JPEG. This index determines the upper bound for i.", ) +parser.add_argument( + "--last_checkpoint", + default="", + help="Last model checkpoint file to resume training from.", +) +parser.add_argument( + "--generative_augmentation_prob", + default=None, + type=float, + help="The probability of applying a generative model augmentation to a view. Applies to the views separately.", +) def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: @@ -231,6 +242,7 @@ def main(): args.synthetic_data_dir, index_min=args.synthetic_index_min, index_max=args.synthetic_index_max, + generative_augmentation_prob=args.generative_augmentation_prob, ) else: print(f"Using real data for training at {args.data}.") diff --git a/simsiam/LARC.py b/simsiam/LARC.py new file mode 100644 index 0000000..fe41b13 --- /dev/null +++ b/simsiam/LARC.py @@ -0,0 +1,107 @@ +import torch +from torch import nn +from torch.nn.parameter import Parameter + + +class LARC(object): + """ + :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, + in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive + local learning rate for each individual parameter. The algorithm is designed to improve + convergence of large batch training. + + See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. + In practice it modifies the gradients of parameters as a proxy for modifying the learning rate + of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + ``` + It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + optim = apex.fp16_utils.FP16_Optimizer(optim) + ``` + Args: + optimizer: Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 + clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. + eps: epsilon kludge to help with numerical stability while calculating adaptive_lr + """ + + def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): + self.optim = optimizer + self.trust_coefficient = trust_coefficient + self.eps = eps + self.clip = clip + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + @property + def param_groups(self): + return self.optim.param_groups + + @param_groups.setter + def param_groups(self, value): + self.optim.param_groups = value + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) + + def step(self): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group["weight_decay"] if "weight_decay" in group else 0 + weight_decays.append(weight_decay) + group["weight_decay"] = 0 + for p in group["params"]: + if p.grad is None: + continue + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + + if param_norm != 0 and grad_norm != 0: + # calculate adaptive lr + weight decay + adaptive_lr = ( + self.trust_coefficient + * (param_norm) + / (grad_norm + param_norm * weight_decay + self.eps) + ) + + # clip learning rate for LARC + if self.clip: + # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` + adaptive_lr = min(adaptive_lr / group["lr"], 1) + + p.grad.data += weight_decay * p.data + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group["weight_decay"] = weight_decays[i] diff --git a/simsiam/LICENSE b/simsiam/LICENSE new file mode 100644 index 0000000..105a4fb --- /dev/null +++ b/simsiam/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/simsiam/README.md b/simsiam/README.md new file mode 100644 index 0000000..47bab1b --- /dev/null +++ b/simsiam/README.md @@ -0,0 +1,96 @@ +# SimSiam: Exploring Simple Siamese Representation Learning + +

+ simsiam +

+ +This is a PyTorch implementation of the [SimSiam paper](https://arxiv.org/abs/2011.10566): +``` +@Article{chen2020simsiam, + author = {Xinlei Chen and Kaiming He}, + title = {Exploring Simple Siamese Representation Learning}, + journal = {arXiv preprint arXiv:2011.10566}, + year = {2020}, +} +``` + +### Preparation + +Install PyTorch and download the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). Similar to [MoCo](https://github.com/facebookresearch/moco), the code release contains minimal modifications for both unsupervised pre-training and linear classification to that code. + +In addition, install [apex](https://github.com/NVIDIA/apex) for the [LARS](https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py) implementation needed for linear classification. + +### Unsupervised Pre-Training + +Only **multi-gpu**, **DistributedDataParallel** training is supported; single-gpu or DataParallel training is not supported. + +To do unsupervised pre-training of a ResNet-50 model on ImageNet in an 8-gpu machine, run: +``` +python main_simsiam.py \ + -a resnet50 \ + --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \ + --fix-pred-lr \ + [your imagenet-folder with train and val folders] +``` +The script uses all the default hyper-parameters as described in the paper, and uses the default augmentation recipe from [MoCo v2](https://arxiv.org/abs/2003.04297). + +The above command performs pre-training with a non-decaying predictor learning rate for 100 epochs, corresponding to the last row of Table 1 in the paper. + +### Linear Classification + +With a pre-trained model, to train a supervised linear classifier on frozen features/weights in an 8-gpu machine, run: +``` +python main_lincls.py \ + -a resnet50 \ + --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \ + --pretrained [your checkpoint path]/checkpoint_0099.pth.tar \ + --lars \ + [your imagenet-folder with train and val folders] +``` + +The above command uses LARS optimizer and a default batch size of 4096. + +### Models and Logs + +Our pre-trained ResNet-50 models and logs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
pre-train
epochs
batch
size
pre-train
ckpt
pre-train
log
linear cls.
ckpt
linear cls.
log
top-1 acc.
100512linklinklinklink68.1
100256linklinklinklink68.3
+ +Settings for the above: 8 NVIDIA V100 GPUs, CUDA 10.1/CuDNN 7.6.5, PyTorch 1.7.0. + +### Transferring to Object Detection + +Same as [MoCo](https://github.com/facebookresearch/moco) for object detection transfer, please see [moco/detection](https://github.com/facebookresearch/moco/tree/master/detection). + + +### License + +This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. \ No newline at end of file diff --git a/simsiam/__init__.py b/simsiam/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simsiam/adil_linear_eval.py b/simsiam/adil_linear_eval.py new file mode 100644 index 0000000..3911691 --- /dev/null +++ b/simsiam/adil_linear_eval.py @@ -0,0 +1,464 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +import os +import random +import shutil +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torchvision import datasets, models, transforms +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from torch import distributed as dist + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) +parser.add_argument( + "--pretrained_checkpoint", + default="", + type=str, + help="Path to simsiam pretrained checkpoint.", +) +parser.add_argument("--lars", action="store_true", help="Use LARS") +parser.add_argument( + "--checkpoint_dir", + default="", + help="Checkpoint directory to save eval model checkpoints.", +) + + +best_acc1 = 0 + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + +def setup() -> None: + """Initialize the process group.""" + dist.init_process_group("nccl") + + +def cleanup() -> None: + """Clean up the process group after training.""" + dist.destroy_process_group() + + +def main(): + args = parser.parse_args() + global best_acc1 + + # torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + # dist_utils.init_distributed_mode( + # launcher=args.distributed_launcher, + # backend=args.distributed_backend, + # ) + setup() + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + torch.cuda.empty_cache() + device_id = torch.cuda.current_device() + else: + device_id = None + + # create model + print(f"Creating model {args.arch}") + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained_checkpoint: + if os.path.isfile(args.pretrained_checkpoint): + print(f"Loading checkpoint {args.pretrained_checkpoint}") + checkpoint = torch.load(args.pretrained_checkpoint, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + else: + raise ValueError(f"No checkpoint found at: {args.pretrained_checkpoint}") + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + # torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(device_id) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + optimizer = torch.optim.SGD( + parameters, init_lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + if args.lars: + print("Use LARS optimizer.") + # from apex.parallel.LARC import LARC + from LARC import LARC + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + cudnn.benchmark = True + + # Data loading code + train_dir = os.path.join(args.data_dir, "train") + val_dir = os.path.join(args.data_dir, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + seed=args.seed, + ) + else: + train_sampler = None + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=True, # TODO(arashaf): this was set to false in training script. + ) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ), + batch_size=256, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + for epoch in tqdm(range(args.epochs)): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, device_id, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.checkpoint_dir and dist_utils.get_rank() == 0: + os.makedirs(args.checkpoint_dir, exist_ok=True) + checkpoint_name = "eval_checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + checkpoint_file, + ) + if epoch == 0: + sanity_check(model.state_dict(), args.pretrained_checkpoint) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + for images, target in tqdm(train_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + +def validate(val_loader, model, criterion, device_id, args): + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + for images, target in tqdm(val_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + print( + "Validation Accuracy@1 {top1.avg:.3f}, Accuracy@5 {top5.avg:.3f}".format( + top1=top1, top5=top5 + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print(f"Loading {pretrained_weights} for sanity check") + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("Sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/adil_linear_eval_original_logs.py b/simsiam/adil_linear_eval_original_logs.py new file mode 100644 index 0000000..537040b --- /dev/null +++ b/simsiam/adil_linear_eval_original_logs.py @@ -0,0 +1,525 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import math +import os +import random +import shutil +import time +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torchvision import datasets, models, transforms +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from LARC import LARC +from torch import distributed as dist + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) +parser.add_argument( + "--pretrained_checkpoint", + default="", + type=str, + help="Path to simsiam pretrained checkpoint.", +) +parser.add_argument("--lars", action="store_true", help="Use LARS") +parser.add_argument( + "--checkpoint_dir", + default="", + help="Checkpoint directory to save eval model checkpoints.", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) + +best_acc1 = 0 + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + +def setup() -> None: + """Initialize the process group.""" + dist.init_process_group("nccl") + + +def cleanup() -> None: + """Clean up the process group after training.""" + dist.destroy_process_group() + + +def main(): + args = parser.parse_args() + global best_acc1 + + # torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + setup() + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + torch.cuda.empty_cache() + device_id = torch.cuda.current_device() + else: + device_id = None + + # create model + print(f"Creating model {args.arch}") + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained_checkpoint: + if os.path.isfile(args.pretrained_checkpoint): + print(f"Loading checkpoint {args.pretrained_checkpoint}") + checkpoint = torch.load(args.pretrained_checkpoint, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + else: + raise ValueError(f"No checkpoint found at: {args.pretrained_checkpoint}") + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + # torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(device_id) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + # TODO(arashaf): Enable Adam optimizer + optimizer = torch.optim.SGD( + parameters, + init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + if args.lars: + print("Use LARS optimizer.") + LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + cudnn.benchmark = True + + # Data loading code + train_dir = os.path.join(args.data_dir, "train") + val_dir = os.path.join(args.data_dir, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + seed=args.seed, + ) + else: + train_sampler = None + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=True, # TODO(arashaf): this was set to false in training script. + ) + + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=256, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + for epoch in tqdm(range(args.epochs)): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, device_id, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.checkpoint_dir and dist_utils.get_rank() == 0: + os.makedirs(args.checkpoint_dir, exist_ok=True) + checkpoint_name = "eval_checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + checkpoint_file, + ) + if epoch == 0: + sanity_check(model.state_dict(), args.pretrained_checkpoint) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch), + ) + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, device_id, args): + batch_time = AverageMeter("Time", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " + ) + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + print( + "Validation Accuracy@1 {top1.avg:.3f}, Accuracy@5 {top5.avg:.3f}".format( + top1=top1, + top5=top5, + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print(f"Loading {pretrained_weights} for sanity check") + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("Sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/adil_main_simsiam.py b/simsiam/adil_main_simsiam.py new file mode 100644 index 0000000..bdd3de9 --- /dev/null +++ b/simsiam/adil_main_simsiam.py @@ -0,0 +1,438 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +import os +import random +from datetime import datetime +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import distributed as dist +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torch.utils.data.distributed import DistributedSampler +from torchvision import datasets, models +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from simsiam import builder, loader + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=100, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=256, + type=int, + metavar="N", + help="mini-batch size (default: 512), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.05, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument( + "--momentum", default=0.9, type=float, metavar="M", help="momentum of SGD solver" +) +parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", +) +parser.add_argument( + "--resume_from_checkpoint", + default="", + type=str, + help="Path to latest checkpoint.", +) +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) + +# simsiam specific configs: +parser.add_argument( + "--dim", default=2048, type=int, help="feature dimension (default: 2048)" +) +parser.add_argument( + "--pred-dim", + default=512, + type=int, + help="hidden dimension of the predictor (default: 512)", +) +parser.add_argument( + "--fix-pred-lr", action="store_true", help="Fix learning rate for the predictor" +) + +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--checkpoint_dir", + default="/projects/imagenet_synthetic/model_checkpoints", + help="Checkpoint root directory.", +) +parser.add_argument( + "--experiment", + default="", + help="Experiment name.", +) +parser.add_argument( + "--use_synthetic_data", + action=argparse.BooleanOptionalAction, + help="Whether to use real data or synthetic data for training.", +) +parser.add_argument( + "--synthetic_data_dir", + default="/projects/imagenet_synthetic/", + help="Path to the root of synthetic data.", +) +parser.add_argument( + "--synthetic_index_min", + default=0, + type=int, + help="Synthetic data files are named filename_i.JPEG. This index determines the lower bound for i.", +) +parser.add_argument( + "--synthetic_index_max", + default=9, + type=int, + help="Synthetic data files are named filename_i.JPEG. This index determines the upper bound for i.", +) +parser.add_argument( + "--generative_augmentation_prob", + default=None, + type=float, + help="The probability of applying a generative model augmentation to a view. Applies to the views separately.", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + +def setup() -> None: + """Initialize the process group.""" + dist.init_process_group("nccl") + + +def cleanup() -> None: + """Clean up the process group after training.""" + dist.destroy_process_group() + + +def main(): + args = parser.parse_args() + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M") + checkpoint_subdir = ( + f"{args.experiment}_{current_time}" if args.experiment else f"{current_time}" + ) + args.checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_subdir) + os.makedirs(args.checkpoint_dir, exist_ok=True) + + print(args) + + # torch.multiprocessing.set_start_method("spawn") + # torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + # dist_utils.init_distributed_mode( + # launcher=args.distributed_launcher, + # backend=args.distributed_backend, + # ) + setup() + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + torch.cuda.empty_cache() + device_id = torch.cuda.current_device() + else: + device_id = None + + # Data loading. + if args.use_synthetic_data: + print( + f"Using synthetic data for training at {args.synthetic_data_dir} between indices {args.synthetic_index_min} and {args.synthetic_index_max}." + ) + train_dataset = loader.ImageNetSynthetic( + args.data_dir, + args.synthetic_data_dir, + index_min=args.synthetic_index_min, + index_max=args.synthetic_index_max, + generative_augmentation_prob=args.generative_augmentation_prob, + ) + else: + print(f"Using real data for training at {args.data_dir}.") + train_data_dir = os.path.join(args.data_dir, "train") + train_dataset = datasets.ImageFolder(train_data_dir, loader.TwoCropsTransform()) + + train_sampler = None + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = DistributedSampler( + train_dataset, + seed=args.seed, + drop_last=True, + ) + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=False, + drop_last=True, + ) + if dist_utils.get_rank() == 0: + print(f"Creating model {args.arch}") + model = builder.SimSiam(models.__dict__[args.arch], args.dim, args.pred_dim) + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + # Apply SyncBN + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + # set the single device scope, otherwise DistributedDataParallel will + # use all available devices + # torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + if dist_utils.get_rank() == 0: + print(model) # print model after SyncBatchNorm + + # define loss function (criterion) and optimizer + criterion = nn.CosineSimilarity(dim=1).cuda(device_id) + + if args.fix_pred_lr: + optim_params = [ + {"params": model.module.encoder.parameters(), "fix_lr": False}, + {"params": model.module.predictor.parameters(), "fix_lr": True}, + ] + else: + optim_params = model.parameters() + + # infer learning rate before changing batch size + # init_lr = args.lr * args.batch_size / 256.0 + # TODO(arashaf): Hard-code init-lr to match the original paper with bs=512. + init_lr = args.lr * 2.0 + + optimizer = torch.optim.SGD( + optim_params, + init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + + start_epoch = 0 + # Optionally resume from a checkpoint + if args.resume_from_checkpoint: + if os.path.isfile(args.resume_from_checkpoint): + print(f"Loading checkpoint: {args.resume_from_checkpoint}") + checkpoint = torch.load(args.resume_from_checkpoint) + start_epoch = checkpoint["epoch"] + 1 + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + print(f"Loaded checkpoint {args.resume_from_checkpoint} successfully.") + else: + raise ValueError(f"No checkpoint found at: {args.resume_from_checkpoint}") + + cudnn.benchmark = True + + for epoch in range(start_epoch, args.epochs): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # Checkpointing. + if dist_utils.get_rank() == 0: + checkpoint_name = "checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + }, + filename=checkpoint_file, + ) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + """Single epoch training code.""" + losses = AverageMeter("Loss", ":.4f") + progress = ProgressMeter( + len(train_loader), + [losses], + prefix="Epoch: [{}]".format(epoch), + ) + + # switch to train mode + model.train() + + for i, (images, _) in enumerate(train_loader): + # for images, _ in tqdm(train_loader): + images[0] = images[0].cuda(device_id, non_blocking=True) + images[1] = images[1].cuda(device_id, non_blocking=True) + + # compute output and loss + p1, p2, z1, z2 = model(x1=images[0], x2=images[1]) + loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5 + + losses.update(loss.item(), images[0].size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i % args.print_freq == 0: + progress.display(i) + + +def save_checkpoint(state, filename="checkpoint.pth.tar"): + """Save state dictionary into a model checkpoint.""" + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule.""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + if "fix_lr" in param_group and param_group["fix_lr"]: + param_group["lr"] = init_lr + else: + param_group["lr"] = cur_lr + + +if __name__ == "__main__": + main() diff --git a/simsiam/builder.py b/simsiam/builder.py new file mode 100644 index 0000000..7ca8c50 --- /dev/null +++ b/simsiam/builder.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +class SimSiam(nn.Module): + """ + Build a SimSiam model. + """ + def __init__(self, base_encoder, dim=2048, pred_dim=512): + """ + dim: feature dimension (default: 2048) + pred_dim: hidden dimension of the predictor (default: 512) + """ + super(SimSiam, self).__init__() + + # create the encoder + # num_classes is the output fc dimension, zero-initialize last BNs + self.encoder = base_encoder(num_classes=dim, zero_init_residual=True) + + # build a 3-layer projector + prev_dim = self.encoder.fc.weight.shape[1] + self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False), + nn.BatchNorm1d(prev_dim), + nn.ReLU(inplace=True), # first layer + nn.Linear(prev_dim, prev_dim, bias=False), + nn.BatchNorm1d(prev_dim), + nn.ReLU(inplace=True), # second layer + self.encoder.fc, + nn.BatchNorm1d(dim, affine=False)) # output layer + self.encoder.fc[6].bias.requires_grad = False # hack: not use bias as it is followed by BN + + # build a 2-layer predictor + self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False), + nn.BatchNorm1d(pred_dim), + nn.ReLU(inplace=True), # hidden layer + nn.Linear(pred_dim, dim)) # output layer + + def forward(self, x1, x2): + """ + Input: + x1: first views of images + x2: second views of images + Output: + p1, p2, z1, z2: predictors and targets of the network + See Sec. 3 of https://arxiv.org/abs/2011.10566 for detailed notations + """ + + # compute features for one view + z1 = self.encoder(x1) # NxC + z2 = self.encoder(x2) # NxC + + p1 = self.predictor(z1) # NxC + p2 = self.predictor(z2) # NxC + + return p1, p2, z1.detach(), z2.detach() diff --git a/simsiam/inatural_dataset.py b/simsiam/inatural_dataset.py new file mode 100644 index 0000000..42a46f1 --- /dev/null +++ b/simsiam/inatural_dataset.py @@ -0,0 +1,77 @@ +import torch.utils.data as data +from PIL import Image +import os +import json +from torchvision import transforms +import random +import numpy as np + + +def default_loader(path): + return Image.open(path).convert('RGB') + +def load_taxonomy(ann_data, tax_levels, classes): + # loads the taxonomy data and converts to ints + taxonomy = {} + + if 'categories' in ann_data.keys(): + num_classes = len(ann_data['categories']) + for tt in tax_levels: + tax_data = [aa[tt] for aa in ann_data['categories']] + _, tax_id = np.unique(tax_data, return_inverse=True) + taxonomy[tt] = dict(zip(range(num_classes), list(tax_id))) + else: + # set up dummy data + for tt in tax_levels: + taxonomy[tt] = dict(zip([0], [0])) + + # create a dictionary of lists containing taxonomic labels + classes_taxonomic = {} + for cc in np.unique(classes): + tax_ids = [0]*len(tax_levels) + for ii, tt in enumerate(tax_levels): + tax_ids[ii] = taxonomy[tt][cc] + classes_taxonomic[cc] = tax_ids + + return taxonomy, classes_taxonomic + + +class INAT(data.Dataset): + def __init__(self, root, ann_file, transform): + + # load annotations + print('Loading annotations from: ' + os.path.basename(ann_file)) + with open(ann_file) as data_file: + ann_data = json.load(data_file) + + # set up the filenames and annotations + self.imgs = [aa['file_name'] for aa in ann_data['images']] + self.ids = [aa['id'] for aa in ann_data['images']] + + # if we dont have class labels set them to '0' + if 'annotations' in ann_data.keys(): + self.classes = [aa['category_id'] for aa in ann_data['annotations']] + else: + self.classes = [0]*len(self.imgs) + + # print out some stats + print('\t' + str(len(self.imgs)) + ' images') + print('\t' + str(len(set(self.classes))) + ' classes') + + self.root = root + self.loader = default_loader + + # augmentation params + self.transform = transform + + def __getitem__(self, index): + path = self.root + self.imgs[index] + img = self.loader(path) + species_id = self.classes[index] + + img = self.transform(img) + + return img, species_id + + def __len__(self): + return len(self.imgs) \ No newline at end of file diff --git a/simsiam/linear_eval.py b/simsiam/linear_eval.py new file mode 100755 index 0000000..9bf31a4 --- /dev/null +++ b/simsiam/linear_eval.py @@ -0,0 +1,449 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +import os +import random +import shutil +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torchvision import datasets, models, transforms +from tqdm import tqdm + +from SimCLR import distributed as dist_utils + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) +parser.add_argument( + "--pretrained_checkpoint", + default="", + type=str, + help="Path to simsiam pretrained checkpoint.", +) +parser.add_argument("--lars", action="store_true", help="Use LARS") +parser.add_argument( + "--checkpoint_dir", + default="", + help="Checkpoint directory to save eval model checkpoints.", +) + + +best_acc1 = 0 + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + +def main(): + args = parser.parse_args() + global best_acc1 + + torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + dist_utils.init_distributed_mode( + launcher=args.distributed_launcher, + backend=args.distributed_backend, + ) + device_id = torch.cuda.current_device() + else: + device_id = None + + # create model + print(f"Creating model {args.arch}") + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained_checkpoint: + if os.path.isfile(args.pretrained_checkpoint): + print(f"Loading checkpoint {args.pretrained_checkpoint}") + checkpoint = torch.load(args.pretrained_checkpoint, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + else: + raise ValueError(f"No checkpoint found at: {args.pretrained_checkpoint}") + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size * 4.0 / 256.0 + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(device_id) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + optimizer = torch.optim.SGD( + parameters, init_lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + if args.lars: + print("Use LARS optimizer.") + from LARC import LARC + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + cudnn.benchmark = True + + # Data loading code + train_dir = os.path.join(args.data_dir, "train") + val_dir = os.path.join(args.data_dir, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + seed=args.seed, + ) + else: + train_sampler = None + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=True, # TODO(arashaf): this was set to false in training script. + ) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ), + batch_size=256, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + for epoch in tqdm(range(args.epochs)): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, device_id, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.checkpoint_dir and dist_utils.get_rank() == 0: + os.makedirs(args.checkpoint_dir, exist_ok=True) + checkpoint_name = "eval_checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + checkpoint_file, + ) + if epoch == 0: + sanity_check(model.state_dict(), args.pretrained_checkpoint) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + for images, target in tqdm(train_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + +def validate(val_loader, model, criterion, device_id, args): + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + for images, target in tqdm(val_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + print( + "Validation Accuracy@1 {top1.avg:.3f}, Accuracy@5 {top5.avg:.3f}".format( + top1=top1, top5=top5 + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print(f"Loading {pretrained_weights} for sanity check") + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("Sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/linear_eval_downstream_datasets.py b/simsiam/linear_eval_downstream_datasets.py new file mode 100644 index 0000000..616b145 --- /dev/null +++ b/simsiam/linear_eval_downstream_datasets.py @@ -0,0 +1,498 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +import os +import random +import shutil +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torchvision import datasets, models, transforms +from tqdm import tqdm + +from SimCLR import distributed as dist_utils + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) +parser.add_argument( + "--pretrained_checkpoint", + default="", + type=str, + help="Path to simsiam pretrained checkpoint.", +) +parser.add_argument("--lars", action="store_true", help="Use LARS") +parser.add_argument( + "--checkpoint_dir", + default="", + help="Checkpoint directory to save eval model checkpoints.", +) +parser.add_argument("--dataset_name", default="imagenet", help="Name of the dataset.") + +best_acc1 = 0 + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + +def main(): + args = parser.parse_args() + global best_acc1 + + torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + dist_utils.init_distributed_mode( + launcher=args.distributed_launcher, + backend=args.distributed_backend, + ) + device_id = torch.cuda.current_device() + else: + device_id = None + + # create model + print(f"Creating model {args.arch}") + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained_checkpoint: + if os.path.isfile(args.pretrained_checkpoint): + print(f"Loading checkpoint {args.pretrained_checkpoint}") + checkpoint = torch.load(args.pretrained_checkpoint, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + else: + raise ValueError(f"No checkpoint found at: {args.pretrained_checkpoint}") + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size * 4 / 256 + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(device_id) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + optimizer = torch.optim.SGD( + parameters, init_lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + if args.lars: + print("Use LARS optimizer.") + # from apex.parallel.LARC import LARC + from LARC import LARC + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + cudnn.benchmark = True + + # Data loading code + train_dir = os.path.join(args.data_dir, "train") + val_dir = os.path.join(args.data_dir, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + if args.dataset_name == "imagenet": + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + elif args.dataset_name == "food101": + train_dataset=datasets.Food101( + root=args.data_dir, + split="train", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ),) + val_dataset=datasets.Food101( + root=args.data_dir, + split="test", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ),) + elif args.dataset_name == "places365": + train_dataset=datasets.Places365( + root=args.data_dir, + split="train-standard", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ),) + val_dataset=datasets.Places365( + root=args.data_dir, + split="val", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ),) + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + seed=args.seed, + ) + else: + train_sampler = None + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=True, # TODO(arashaf): this was set to false in training script. + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=256, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + for epoch in tqdm(range(args.epochs)): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, device_id, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.checkpoint_dir and dist_utils.get_rank() == 0: + os.makedirs(args.checkpoint_dir, exist_ok=True) + checkpoint_name = "eval_checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + checkpoint_file, + ) + if epoch == 0: + sanity_check(model.state_dict(), args.pretrained_checkpoint) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + for images, target in tqdm(train_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + +def validate(val_loader, model, criterion, device_id, args): + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + for images, target in tqdm(val_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + print( + "Validation Accuracy@1 {top1.avg:.3f}, Accuracy@5 {top5.avg:.3f}".format( + top1=top1, top5=top5 + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print(f"Loading {pretrained_weights} for sanity check") + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("Sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/linear_eval_original_code.py b/simsiam/linear_eval_original_code.py new file mode 100644 index 0000000..5d85f85 --- /dev/null +++ b/simsiam/linear_eval_original_code.py @@ -0,0 +1,848 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings +from datetime import datetime + +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.models as models +import torchvision.transforms as transforms +from tqdm import tqdm +from icgan.data_utils import utils as data_utils + +from inatural_dataset import INAT + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) +parser.add_argument( + "--resume", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint (default: none)", +) +parser.add_argument( + "-e", + "--evaluate", + dest="evaluate", + action="store_true", + help="evaluate model on validation set", +) +parser.add_argument( + "--world-size", + default=-1, + type=int, + help="number of nodes for distributed training", +) +parser.add_argument( + "--rank", default=-1, type=int, help="node rank for distributed training" +) +parser.add_argument( + "--dist-url", + default="tcp://224.66.41.62:23456", + type=str, + help="url used to set up distributed training", +) +parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" +) +parser.add_argument( + "--seed", default=None, type=int, help="seed for initializing training. " +) +parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") +parser.add_argument( + "--multiprocessing-distributed", + action="store_true", + help="Use multi-processing distributed training to launch " + "N processes per node, which has N GPUs. This is the " + "fastest way to use PyTorch for either single node or " + "multi node data parallel training", +) + +# additional configs: +parser.add_argument( + "--pretrained", default="", type=str, help="path to simsiam pretrained checkpoint" +) +parser.add_argument("--lars", action="store_true", help="Use LARS") + +parser.add_argument("--dataset_name", default="imagenet", help="Name of the dataset.") + +parser.add_argument( + "--checkpoint_dir", + default="/projects/imagenet_synthetic/model_checkpoints", + help="Checkpoint root directory.", +) + +parser.add_argument( + "--num_classes", + default=1000, + type=int, + help="Number of classes in the dataset.", +) + +parser.add_argument( + "--ablation_mode", + default="icgan", + type=str, + help="Using icgan or stable diffusion feature extractor for ablation study.", +) + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M") + args.checkpoint_dir = os.path.join(args.checkpoint_dir, f"eval_{current_time}") + os.makedirs(args.checkpoint_dir, exist_ok=True) + + print(args) + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + # NOTE: this line can reduce speed considerably + # cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + if args.gpu is not None: + warnings.warn( + "You have chosen a specific GPU. This will completely " + "disable data parallelism." + ) + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + print(args.world_size) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + print("second", args.world_size) + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn( + main_worker, + nprocs=ngpus_per_node, + args=( + ngpus_per_node, + args, + ), + ) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + print("spawn performed, gpu", gpu, flush=True) + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + + def print_pass(*args, flush=True): + pass + + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu), flush=True) + + if args.distributed: + print("here", flush=True) + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + print("rank", args.rank, flush=True) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + print("second rank", args.rank, flush=True) + dist.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + print("init_process_group", flush=True) + torch.distributed.barrier() + + + # create model + print("=> creating model '{}'".format(args.arch), flush=True) + model = models.__dict__[args.arch]() + + model.fc = nn.Linear(2048, args.num_classes) + + print("model", model.state_dict().keys(), flush=True) + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained: + if os.path.isfile(args.pretrained): + print("=> loading checkpoint '{}'".format(args.pretrained), flush=True) + checkpoint = torch.load(args.pretrained, map_location="cpu") + + # rename moco pre-trained keys + if args.ablation_mode == "icgan": + state_dict = checkpoint + else: + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if args.ablation_mode == "icgan": + if k.startswith("module") and not k.startswith( + "module.fc" + ): + # remove prefix + state_dict[k[len("module.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + else: + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + args.start_epoch = 0 + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + + print("=> loaded pre-trained model '{}'".format(args.pretrained)) + else: + print("=> no checkpoint found at '{}'".format(args.pretrained)) + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + print("batchsize", args.batch_size, flush=True) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + print("workers", args.workers, flush=True) + print("gpu", args.gpu, flush=True) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu] + ) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith("alexnet") or args.arch.startswith("vgg"): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = torch.nn.DataParallel(model).cuda() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + optimizer = torch.optim.SGD( + parameters, init_lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + if args.lars: + print("=> use LARS optimizer.", flush=True) + from LARC import LARC + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume), flush=True) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = "cuda:{}".format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint["epoch"] + best_acc1 = checkpoint["best_acc1"] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + print( + "=> loaded checkpoint '{}' (epoch {})".format( + args.resume, checkpoint["epoch"] + ), + flush=True, + ) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, "train") + valdir = os.path.join(args.data, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if args.dataset_name == "imagenet": + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + valdir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ] + ), + ) + elif args.dataset_name == "food101": + print("=> using food101 dataset.", flush=True) + train_dataset = datasets.Food101( + root=args.data, + split="train", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = datasets.Food101( + root=args.data, + split="test", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + elif args.dataset_name == "cifar10": + train_dataset = datasets.CIFAR10( + root=args.data, + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + val_dataset = datasets.CIFAR10( + root=args.data, + train=False, + download=True, + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + elif args.dataset_name == "cifar100": + train_dataset = datasets.CIFAR100( + root=args.data, + train=True, + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + val_dataset = datasets.CIFAR100( + root=args.data, + train=False, + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + elif args.dataset_name == "places365": + train_dataset = datasets.Places365( + root=args.data, + split="train-standard", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = datasets.Places365( + root=args.data, + split="val", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + elif args.dataset_name == "INaturalist": + train_dataset = INAT( + root=args.data, + ann_file=os.path.join(args.data, "train2018.json"), + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = INAT( + root=args.data, + ann_file=os.path.join(args.data, "val2018.json"), + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + pin_memory=True, + sampler=train_sampler, + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=256, + shuffle=False, + num_workers=args.workers, + pin_memory=True, + ) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + print("epoch", epoch, flush=True) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if not args.multiprocessing_distributed or ( + args.multiprocessing_distributed and args.rank % ngpus_per_node == 0 + ): + checkpoint_name = "checkpoint_{:04d}.pth.tar".format(epoch + 1) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch + 1, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + filename=checkpoint_file, + ) + if epoch == args.start_epoch: + sanity_check(model.state_dict(), args.pretrained, args.ablation_mode) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch), + ) + + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + i = 0 + for images, target in tqdm(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + if i == 0: + print("first step passed", flush=True) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + i += 1 + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter("Time", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " + ) + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + i = 0 + for images, target in tqdm(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + i += 1 + + # # TODO: this should also be done with the ProgressMeter + print( + "\n * Accuracy@1 {top1.avg:.3f} Accuracy@5 {top5.avg:.3f}".format( + top1=top1, top5=top5 + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights, ablation_mode): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading '{}' for sanity check".format(pretrained_weights)) + checkpoint = torch.load(pretrained_weights, map_location="cpu") + if ablation_mode == "icgan": + state_dict_pre = checkpoint + else: + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + if ablation_mode == "icgan": + # name in pretrained model + k_pre = ( + "module." + k[len("module.") :] + if k.startswith("module.") + else "module." + k + ) + + else: + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("=> sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries), flush=True) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/linear_eval_original_code_clip.py b/simsiam/linear_eval_original_code_clip.py new file mode 100644 index 0000000..e737615 --- /dev/null +++ b/simsiam/linear_eval_original_code_clip.py @@ -0,0 +1,865 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings +from datetime import datetime + +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.models as models +import torchvision.transforms as transforms +from tqdm import tqdm +import clip +from inatural_dataset import INAT + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +# parser.add_argument( +# "-a", +# "--arch", +# metavar="ARCH", +# default="resnet50", +# choices=model_names, +# help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +# ) +parser.add_argument( + "-j", + "--workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) +parser.add_argument( + "--resume", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint (default: none)", +) +parser.add_argument( + "-e", + "--evaluate", + dest="evaluate", + action="store_true", + help="evaluate model on validation set", +) +parser.add_argument( + "--world-size", + default=-1, + type=int, + help="number of nodes for distributed training", +) +parser.add_argument( + "--rank", default=-1, type=int, help="node rank for distributed training" +) +parser.add_argument( + "--dist-url", + default="tcp://224.66.41.62:23456", + type=str, + help="url used to set up distributed training", +) +parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" +) +parser.add_argument( + "--seed", default=None, type=int, help="seed for initializing training. " +) +parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") +parser.add_argument( + "--multiprocessing-distributed", + action="store_true", + help="Use multi-processing distributed training to launch " + "N processes per node, which has N GPUs. This is the " + "fastest way to use PyTorch for either single node or " + "multi node data parallel training", +) + +# additional configs: +parser.add_argument( + "--pretrained", default="", type=str, help="path to simsiam pretrained checkpoint" +) +parser.add_argument("--lars", action="store_true", help="Use LARS") + +parser.add_argument("--dataset_name", default="imagenet", help="Name of the dataset.") + +parser.add_argument( + "--checkpoint_dir", + default="/projects/imagenet_synthetic/model_checkpoints", + help="Checkpoint root directory.", +) + +parser.add_argument( + "--num_classes", + default=1000, + type=int, + help="Number of classes in the dataset.", +) + +# parser.add_argument( +# "--ablation_mode", +# default="icgan", +# type=str, +# help="Using icgan or stable diffusion feature extractor for ablation study.", +# ) + +best_acc1 = 0 + +class CLIPClassifier(nn.Module): + def __init__(self, clip_model, num_classes): + super(CLIPClassifier, self).__init__() + self.clip_model = clip_model + + for param in self.clip_model.parameters(): + param.requires_grad = False + + self.linear = nn.Linear(clip_model.visual.output_dim, num_classes) + + def forward(self, x): + with torch.no_grad(): + x = self.clip_model.encode_image(x) + x = x.float() + x = self.linear(x) + return x + +def main(): + args = parser.parse_args() + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M") + args.checkpoint_dir = os.path.join(args.checkpoint_dir, f"eval_{current_time}") + os.makedirs(args.checkpoint_dir, exist_ok=True) + + print(args) + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + # NOTE: this line can reduce speed considerably + # cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + if args.gpu is not None: + warnings.warn( + "You have chosen a specific GPU. This will completely " + "disable data parallelism." + ) + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + print(args.world_size) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + print("second", args.world_size) + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn( + main_worker, + nprocs=ngpus_per_node, + args=( + ngpus_per_node, + args, + ), + ) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + print("spawn performed, gpu", gpu, flush=True) + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + + def print_pass(*args, flush=True): + pass + + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu), flush=True) + + if args.distributed: + print("here", flush=True) + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + print("rank", args.rank, flush=True) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + print("second rank", args.rank, flush=True) + dist.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + print("init_process_group", flush=True) + torch.distributed.barrier() + + + # create model + print("=> creating model", flush=True) + # model = models.__dict__[args.arch]() + + # model.fc = nn.Linear(2048, args.num_classes) + # Load the pre-trained CLIP model + model, _ = clip.load("ViT-B/32") + model = model.float() + model = CLIPClassifier(model, args.num_classes) + args.start_epoch = 0 + + print("model", model.state_dict().keys(), flush=True) + + # # freeze all layers but the last fc + # for name, param in model.named_parameters(): + # if name not in ["fc.weight", "fc.bias"]: + # param.requires_grad = False + # # init the fc layer + # model.fc.weight.data.normal_(mean=0.0, std=0.01) + # model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + # if args.pretrained: + # if os.path.isfile(args.pretrained): + # print("=> loading checkpoint '{}'".format(args.pretrained), flush=True) + # checkpoint = torch.load(args.pretrained, map_location="cpu") + + # # rename moco pre-trained keys + # if args.ablation_mode == "icgan": + # state_dict = checkpoint + # else: + # state_dict = checkpoint["state_dict"] + # for k in list(state_dict.keys()): + # # retain only encoder up to before the embedding layer + # if args.ablation_mode == "icgan": + # if k.startswith("module") and not k.startswith( + # "module.fc" + # ): + # # remove prefix + # state_dict[k[len("module.") :]] = state_dict[k] + # # delete renamed or unused k + # del state_dict[k] + # else: + # if k.startswith("module.encoder") and not k.startswith( + # "module.encoder.fc" + # ): + # # remove prefix + # state_dict[k[len("module.encoder.") :]] = state_dict[k] + # # delete renamed or unused k + # del state_dict[k] + # msg = model.load_state_dict(state_dict, strict=False) + # assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + + # print("=> loaded pre-trained model '{}'".format(args.pretrained)) + # else: + # print("=> no checkpoint found at '{}'".format(args.pretrained)) + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + print("batchsize", args.batch_size, flush=True) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + print("workers", args.workers, flush=True) + print("gpu", args.gpu, flush=True) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu] + ) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + # if args.arch.startswith("alexnet") or args.arch.startswith("vgg"): + # model.features = torch.nn.DataParallel(model.features) + # model.cuda() + # else: + model = torch.nn.DataParallel(model).cuda() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + optimizer = torch.optim.SGD( + parameters, init_lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + if args.lars: + print("=> use LARS optimizer.", flush=True) + from LARC import LARC + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume), flush=True) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = "cuda:{}".format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint["epoch"] + best_acc1 = checkpoint["best_acc1"] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + print( + "=> loaded checkpoint '{}' (epoch {})".format( + args.resume, checkpoint["epoch"] + ), + flush=True, + ) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, "train") + valdir = os.path.join(args.data, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if args.dataset_name == "imagenet": + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + valdir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ] + ), + ) + elif args.dataset_name == "food101": + print("=> using food101 dataset.", flush=True) + train_dataset = datasets.Food101( + root=args.data, + split="train", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = datasets.Food101( + root=args.data, + split="test", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + elif args.dataset_name == "cifar10": + train_dataset = datasets.CIFAR10( + root=args.data, + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + val_dataset = datasets.CIFAR10( + root=args.data, + train=False, + download=True, + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + elif args.dataset_name == "cifar100": + train_dataset = datasets.CIFAR100( + root=args.data, + train=True, + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + val_dataset = datasets.CIFAR100( + root=args.data, + train=False, + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ], + ), + ) + elif args.dataset_name == "places365": + train_dataset = datasets.Places365( + root=args.data, + split="train-standard", + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = datasets.Places365( + root=args.data, + split="val", + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + elif args.dataset_name == "INaturalist": + train_dataset = INAT( + root=args.data, + ann_file=os.path.join(args.data, "train2018.json"), + transform=transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_dataset = INAT( + root=args.data, + ann_file=os.path.join(args.data, "val2018.json"), + transform=transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + pin_memory=True, + sampler=train_sampler, + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=256, + shuffle=False, + num_workers=args.workers, + pin_memory=True, + ) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + print("epoch", epoch, flush=True) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if not args.multiprocessing_distributed or ( + args.multiprocessing_distributed and args.rank % ngpus_per_node == 0 + ): + checkpoint_name = "checkpoint_{:04d}.pth.tar".format(epoch + 1) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + filename=checkpoint_file, + ) + # if epoch == args.start_epoch: + # sanity_check(model.state_dict(), args.pretrained, args.ablation_mode) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch), + ) + + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + i = 0 + for images, target in tqdm(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + if i == 0: + print("first step passed", flush=True) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + i += 1 + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter("Time", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " + ) + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + i = 0 + for images, target in tqdm(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + i += 1 + + # # TODO: this should also be done with the ProgressMeter + print( + "\n * Accuracy@1 {top1.avg:.3f} Accuracy@5 {top5.avg:.3f}".format( + top1=top1, top5=top5 + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +# def sanity_check(state_dict, pretrained_weights, ablation_mode): +# """ +# Linear classifier should not change any weights other than the linear layer. +# This sanity check asserts nothing wrong happens (e.g., BN stats updated). +# """ +# print("=> loading '{}' for sanity check".format(pretrained_weights)) +# checkpoint = torch.load(pretrained_weights, map_location="cpu") +# if ablation_mode == "icgan": +# state_dict_pre = checkpoint +# else: +# state_dict_pre = checkpoint["state_dict"] + +# for k in list(state_dict.keys()): +# # only ignore fc layer +# if "fc.weight" in k or "fc.bias" in k: +# continue +# if ablation_mode == "icgan": +# # name in pretrained model +# k_pre = ( +# "module." + k[len("module.") :] +# if k.startswith("module.") +# else "module." + k +# ) + +# else: +# # name in pretrained model +# k_pre = ( +# "module.encoder." + k[len("module.") :] +# if k.startswith("module.") +# else "module.encoder." + k +# ) + +# assert ( +# state_dict[k].cpu() == state_dict_pre[k_pre] +# ).all(), "{} is changed in linear classifier training.".format(k) + +# print("=> sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries), flush=True) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/linear_eval_original_logs.py b/simsiam/linear_eval_original_logs.py new file mode 100755 index 0000000..a3ced4e --- /dev/null +++ b/simsiam/linear_eval_original_logs.py @@ -0,0 +1,517 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import math +import os +import random +import shutil +import time +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torchvision import datasets, models, transforms +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from LARC import LARC + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=4096, + type=int, + metavar="N", + help="mini-batch size (default: 4096), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--wd", + "--weight-decay", + default=0.0, + type=float, + metavar="W", + help="weight decay (default: 0.)", + dest="weight_decay", +) +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) +parser.add_argument( + "--pretrained_checkpoint", + default="", + type=str, + help="Path to simsiam pretrained checkpoint.", +) +parser.add_argument("--lars", action="store_true", help="Use LARS") +parser.add_argument( + "--checkpoint_dir", + default="", + help="Checkpoint directory to save eval model checkpoints.", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) + +best_acc1 = 0 + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + +def main(): + args = parser.parse_args() + global best_acc1 + + torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + dist_utils.init_distributed_mode( + launcher=args.distributed_launcher, + backend=args.distributed_backend, + ) + device_id = torch.cuda.current_device() + else: + device_id = None + + # create model + print(f"Creating model {args.arch}") + model = models.__dict__[args.arch]() + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ["fc.weight", "fc.bias"]: + param.requires_grad = False + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained_checkpoint: + if os.path.isfile(args.pretrained_checkpoint): + print(f"Loading checkpoint {args.pretrained_checkpoint}") + checkpoint = torch.load(args.pretrained_checkpoint, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith( + "module.encoder.fc" + ): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + else: + raise ValueError(f"No checkpoint found at: {args.pretrained_checkpoint}") + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(device_id) + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + assert len(parameters) == 2 # fc.weight, fc.bias + + # TODO(arashaf): Enable Adam optimizer + optimizer = torch.optim.SGD( + parameters, + init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + if args.lars: + print("Use LARS optimizer.") + LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + + cudnn.benchmark = True + + # Data loading code + train_dir = os.path.join(args.data_dir, "train") + val_dir = os.path.join(args.data_dir, "val") + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ], + ), + ) + + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + seed=args.seed, + ) + else: + train_sampler = None + + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=True, # TODO(arashaf): this was set to false in training script. + ) + + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ], + ), + ) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=256, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + for epoch in tqdm(range(args.epochs)): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, device_id, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.checkpoint_dir and dist_utils.get_rank() == 0: + os.makedirs(args.checkpoint_dir, exist_ok=True) + checkpoint_name = "eval_checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + checkpoint_file, + ) + if epoch == 0: + sanity_check(model.state_dict(), args.pretrained_checkpoint) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch), + ) + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, device_id, args): + batch_time = AverageMeter("Time", ":6.3f") + losses = AverageMeter("Loss", ":.4e") + top1 = AverageMeter("Acc@1", ":6.2f") + top5 = AverageMeter("Acc@5", ":6.2f") + progress = ProgressMeter( + len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " + ) + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + images = images.cuda(device_id, non_blocking=True) + target = target.cuda(device_id, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + print( + "Validation Accuracy@1 {top1.avg:.3f}, Accuracy@5 {top5.avg:.3f}".format( + top1=top1, + top5=top5, + ) + ) + + return top1.avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def sanity_check(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print(f"Loading {pretrained_weights} for sanity check") + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint["state_dict"] + + for k in list(state_dict.keys()): + # only ignore fc layer + if "fc.weight" in k or "fc.bias" in k: + continue + + # name in pretrained model + k_pre = ( + "module.encoder." + k[len("module.") :] + if k.startswith("module.") + else "module.encoder." + k + ) + + assert ( + state_dict[k].cpu() == state_dict_pre[k_pre] + ).all(), "{} is changed in linear classifier training.".format(k) + + print("Sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group["lr"] = cur_lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/simsiam/loader.py b/simsiam/loader.py new file mode 100644 index 0000000..69a33ac --- /dev/null +++ b/simsiam/loader.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random + +import torch +from PIL import Image, ImageFilter +from torchvision import datasets, transforms + + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709.""" + + def __init__(self, sigma=[0.1, 2.0]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + +# MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 +_real_augmentations = [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), + transforms.RandomApply( + [ + transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened + ], + p=0.8, + ), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + _normalize, +] + + +class TwoCropsTransform: + """Take two random crops of one image as the query and key.""" + + def __init__(self): + self.base_transform = transforms.Compose(_real_augmentations) + + def __call__(self, x): + q = self.base_transform(x) + k = self.base_transform(x) + return [q, k] + + +class ImageNetSynthetic(datasets.ImageNet): + def __init__( + self, + imagenet_root, + imagenet_synthetic_root, + index_min=0, + index_max=9, + generative_augmentation_prob=None, + load_one_real_image=False, + split="train", + ): + super(ImageNetSynthetic, self).__init__( + root=imagenet_root, + split=split, + ) + self.imagenet_root = imagenet_root + self.imagenet_synthetic_root = imagenet_synthetic_root + self.index_min = index_min + self.index_max = index_max + self.generative_augmentation_prob = generative_augmentation_prob + self.load_one_real_image = load_one_real_image + self.real_transforms = transforms.Compose(_real_augmentations) + # Remove random crop for synthetic image augmentation. + self.synthetic_transforms = transforms.Compose(_real_augmentations[1:]) + self.split = split + + def __getitem__(self, index): + imagenet_filename, label = self.imgs[index] + + def _synthetic_image(filename): + rand_int = random.randint(self.index_min, self.index_max) + filename_and_extension = filename.split("/")[-1] + filename_parent_dir = filename.split("/")[-2] + image_path = os.path.join( + self.imagenet_synthetic_root, + self.split, + filename_parent_dir, + filename_and_extension.split(".")[0] + f"_{rand_int}.JPEG", + ) + return Image.open(image_path).convert("RGB") + + if self.generative_augmentation_prob is not None: + if torch.rand(1) < self.generative_augmentation_prob: + # Generate a synthetic image. + image1 = _synthetic_image(imagenet_filename) + image1 = self.synthetic_transforms(image1) + else: + image1 = self.loader(os.path.join(self.root, imagenet_filename)) + image1 = self.real_transforms(image1) + + if torch.rand(1) < self.generative_augmentation_prob: + # Generate another synthetic image. + image2 = _synthetic_image(imagenet_filename) + image2 = self.synthetic_transforms(image2) + else: + image2 = self.loader(os.path.join(self.root, imagenet_filename)) + image2 = self.real_transforms(image2) + else: + if self.load_one_real_image: + image1 = self.loader(os.path.join(self.root, imagenet_filename)) + image1 = self.real_transforms(image1) + else: + image1 = _synthetic_image(imagenet_filename) + image1 = self.synthetic_transforms(image1) + # image2 is always synthetic. + image2 = _synthetic_image(imagenet_filename) + image2 = self.synthetic_transforms(image2) + + return [image1, image2], label diff --git a/simsiam/main_simsiam.py b/simsiam/main_simsiam.py new file mode 100755 index 0000000..56da4ea --- /dev/null +++ b/simsiam/main_simsiam.py @@ -0,0 +1,422 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +import os +import random +from datetime import datetime +from functools import partial + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch import nn +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817 +from torch.utils.data.distributed import DistributedSampler +from torchvision import datasets, models +from tqdm import tqdm + +from SimCLR import distributed as dist_utils +from simsiam import builder, loader + + +model_names = sorted( + name + for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) +) + +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument( + "--data_dir", + metavar="DIR", + default="/scratch/ssd004/datasets/imagenet256", + help="path to dataset.", +) +parser.add_argument( + "-a", + "--arch", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", + "--num_workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", +) +parser.add_argument( + "--epochs", default=100, type=int, metavar="N", help="number of total epochs to run" +) +parser.add_argument( + "-b", + "--batch-size", + default=256, + type=int, + metavar="N", + help="mini-batch size (default: 512), this is the total " + "batch size of all GPUs on the current node when " + "using Data Parallel or Distributed Data Parallel", +) +parser.add_argument( + "--lr", + "--learning-rate", + default=0.05, + type=float, + metavar="LR", + help="initial (base) learning rate", + dest="lr", +) +parser.add_argument( + "--momentum", default=0.9, type=float, metavar="M", help="momentum of SGD solver" +) +parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", +) +parser.add_argument( + "--resume_from_checkpoint", + default="", + type=str, + help="Path to latest checkpoint.", +) +parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training. " +) + +# simsiam specific configs: +parser.add_argument( + "--dim", default=2048, type=int, help="feature dimension (default: 2048)" +) +parser.add_argument( + "--pred-dim", + default=512, + type=int, + help="hidden dimension of the predictor (default: 512)", +) +parser.add_argument( + "--fix-pred-lr", action="store_true", help="Fix learning rate for the predictor" +) + +parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training", +) +parser.add_argument("--distributed_launcher", default="slurm") +parser.add_argument("--distributed_backend", default="nccl") +parser.add_argument( + "--checkpoint_dir", + default="/projects/imagenet_synthetic/model_checkpoints", + help="Checkpoint root directory.", +) +parser.add_argument( + "--experiment", + default="", + help="Experiment name.", +) +parser.add_argument( + "--use_synthetic_data", + action=argparse.BooleanOptionalAction, + help="Whether to use real data or synthetic data for training.", +) +parser.add_argument( + "--synthetic_data_dir", + default="/projects/imagenet_synthetic/", + help="Path to the root of synthetic data.", +) +parser.add_argument( + "--synthetic_index_min", + default=0, + type=int, + help="Synthetic data files are named filename_i.JPEG. This index determines the lower bound for i.", +) +parser.add_argument( + "--synthetic_index_max", + default=9, + type=int, + help="Synthetic data files are named filename_i.JPEG. This index determines the upper bound for i.", +) +parser.add_argument( + "--generative_augmentation_prob", + default=None, + type=float, + help="The probability of applying a generative model augmentation to a view. Applies to the views separately.", +) +parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", +) + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int) -> None: + """Initialize worker processes with a random seed. + + Parameters + ---------- + worker_id : int + ID of the worker process. + num_workers : int + Total number of workers that will be initialized. + rank : int + The rank of the current process. + seed : int + A random seed used determine the worker seed. + """ + worker_seed = num_workers * rank + worker_id + seed + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + +def main(): + args = parser.parse_args() + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M") + checkpoint_subdir = ( + f"{args.experiment}_{current_time}" if args.experiment else f"{current_time}" + ) + args.checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_subdir) + os.makedirs(args.checkpoint_dir, exist_ok=True) + + print(args) + + torch.multiprocessing.set_start_method("spawn") + if args.distributed_mode: + dist_utils.init_distributed_mode( + launcher=args.distributed_launcher, + backend=args.distributed_backend, + ) + device_id = torch.cuda.current_device() + else: + device_id = None + + # Data loading. + if args.use_synthetic_data: + print( + f"Using synthetic data for training at {args.synthetic_data_dir} between indices {args.synthetic_index_min} and {args.synthetic_index_max}." + ) + train_dataset = loader.ImageNetSynthetic( + args.data_dir, + args.synthetic_data_dir, + index_min=args.synthetic_index_min, + index_max=args.synthetic_index_max, + generative_augmentation_prob=args.generative_augmentation_prob, + ) + else: + print(f"Using real data for training at {args.data_dir}.") + train_data_dir = os.path.join(args.data_dir, "train") + train_dataset = datasets.ImageFolder(train_data_dir, loader.TwoCropsTransform()) + + train_sampler = None + if dist_utils.is_dist_avail_and_initialized() and args.distributed_mode: + train_sampler = DistributedSampler( + train_dataset, + seed=args.seed, + drop_last=True, + ) + init_fn = partial( + worker_init_fn, + num_workers=args.num_workers, + rank=dist_utils.get_rank(), + seed=args.seed, + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=args.num_workers, + worker_init_fn=init_fn, + pin_memory=False, + drop_last=True, + ) + + print(f"Creating model {args.arch}") + model = builder.SimSiam(models.__dict__[args.arch], args.dim, args.pred_dim) + + if args.distributed_mode and dist_utils.is_dist_avail_and_initialized(): + # Apply SyncBN + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + # set the single device scope, otherwise DistributedDataParallel will + # use all available devices + torch.cuda.set_device(device_id) + model = model.cuda(device_id) + model = DDP(model, device_ids=[device_id]) + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + print(model) # print model after SyncBatchNorm + + # define loss function (criterion) and optimizer + criterion = nn.CosineSimilarity(dim=1).cuda(device_id) + + if args.fix_pred_lr: + optim_params = [ + {"params": model.module.encoder.parameters(), "fix_lr": False}, + {"params": model.module.predictor.parameters(), "fix_lr": True}, + ] + else: + optim_params = model.parameters() + + # infer learning rate before changing batch size + # init_lr = args.lr * args.batch_size / 256.0 + # TODO(arashaf): Hard-code init-lr to match the original paper with bs=512. + init_lr = args.lr * 2.0 + + optimizer = torch.optim.SGD( + optim_params, + init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + + start_epoch = 0 + # Optionally resume from a checkpoint + if args.resume_from_checkpoint: + if os.path.isfile(args.resume_from_checkpoint): + print(f"Loading checkpoint: {args.resume_from_checkpoint}") + checkpoint = torch.load(args.resume_from_checkpoint) + start_epoch = checkpoint["epoch"] + 1 + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + print(f"Loaded checkpoint {args.resume_from_checkpoint} successfully.") + else: + raise ValueError(f"No checkpoint found at: {args.resume_from_checkpoint}") + + cudnn.benchmark = True + + for epoch in range(start_epoch, args.epochs): + print(f"Starting training epoch: {epoch}") + if dist_utils.is_dist_avail_and_initialized(): + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, device_id, args) + + # Checkpointing. + if dist_utils.get_rank() == 0: + checkpoint_name = "checkpoint_{:04d}.pth.tar".format(epoch) + checkpoint_file = os.path.join(args.checkpoint_dir, checkpoint_name) + save_checkpoint( + { + "epoch": epoch, + "arch": args.arch, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + }, + filename=checkpoint_file, + ) + + +def train(train_loader, model, criterion, optimizer, epoch, device_id, args): + """Single epoch training code.""" + losses = AverageMeter("Loss", ":.4f") + progress = ProgressMeter( + len(train_loader), + [losses], + prefix="Epoch: [{}]".format(epoch), + ) + + # switch to train mode + model.train() + + for i, (images, _) in enumerate(train_loader): + # for images, _ in tqdm(train_loader): + images[0] = images[0].cuda(device_id, non_blocking=True) + images[1] = images[1].cuda(device_id, non_blocking=True) + + # compute output and loss + p1, p2, z1, z2 = model(x1=images[0], x2=images[1]) + loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5 + + losses.update(loss.item(), images[0].size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i % args.print_freq == 0: + progress.display(i) + + +def save_checkpoint(state, filename="checkpoint.pth.tar"): + """Save state dictionary into a model checkpoint.""" + print(f"Saving checkpoint at: {filename}") + torch.save(state, filename) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule.""" + cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + if "fix_lr" in param_group and param_group["fix_lr"]: + param_group["lr"] = init_lr + else: + param_group["lr"] = cur_lr + + +if __name__ == "__main__": + main() diff --git a/train_simclr.slrm b/train_simclr.slrm index 6a03dca..76800d6 100644 --- a/train_simclr.slrm +++ b/train_simclr.slrm @@ -1,6 +1,6 @@ #!/bin/bash -#SBATCH --job-name=simclr_base +#SBATCH --job-name=simclr_icgan #SBATCH --partition=a100 #SBATCH --qos=a100_arashaf #SBATCH --time=72:00:00 @@ -9,7 +9,7 @@ #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=4 #SBATCH --mem-per-cpu=4G -#SBATCH --output=slurm-%N-%j.out +#SBATCH --output=slurm-%j.out PY_ARGS=${@:1} @@ -31,4 +31,8 @@ srun python run_simCLR.py \ --distributed_mode \ --batch-size=512 \ --epochs=100 \ ---no-use_synthetic_data \ No newline at end of file +--use_synthetic_data \ +--synthetic_data_dir="/projects/imagenet_synthetic/synthetic_icgan" \ +--synthetic_index_min=0 \ +--synthetic_index_max=4 \ +--generative_augmentation_prob=0.5 diff --git a/train_simsiam.slrm b/train_simsiam.slrm new file mode 100644 index 0000000..dc8d23f --- /dev/null +++ b/train_simsiam.slrm @@ -0,0 +1,45 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_train" +#SBATCH --partition=a40 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --time=72:00:00 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + +PY_ARGS=${@:1} + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +export CUDA_LAUNCH_BLOCKING=1 + +export MASTER_ADDR=$(hostname) +export MASTER_PORT=45679 + +export PYTHONPATH="." +nvidia-smi + +# “srun” executes the script times +srun python simsiam/main_simsiam.py \ +-a resnet50 \ +--fix-pred-lr \ +--distributed_mode \ +--batch-size=128 \ +--epochs=100 \ +--experiment="simsiam_stablediff_p0p5_seed43" \ +--resume_from_checkpoint="" \ +--seed=43 \ +--use_synthetic_data \ +--synthetic_data_dir="/projects/imagenet_synthetic/arashaf_stablediff_batched" \ +--synthetic_index_min=0 \ +--synthetic_index_max=1 \ +--generative_augmentation_prob=0.5 + + diff --git a/train_simsiam_multinode.slrm b/train_simsiam_multinode.slrm new file mode 100644 index 0000000..f90f3c8 --- /dev/null +++ b/train_simsiam_multinode.slrm @@ -0,0 +1,57 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_multi_train" +#SBATCH --partition=a40 +#SBATCH --account=deadline +#SBATCH --qos=deadline +#SBATCH --nodes=2 +#SBATCH --gres=gpu:a40:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=01:00:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem-per-cpu=8G +#SBATCH --output=slurm-%j.out + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 + + +export MASTER_ADDR="$(hostname --fqdn)" +export MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')" +export RDVZ_ID=$RANDOM +echo "RDZV Endpoint $MASTER_ADDR:$MASTER_PORT" + +export PYTHONPATH="." +nvidia-smi + +srun -p $SLURM_JOB_PARTITION \ + -c $SLURM_CPUS_ON_NODE \ + -N $SLURM_JOB_NUM_NODES \ + --mem=0 \ + --gres=gpu:$SLURM_JOB_PARTITION:$SLURM_GPUS_ON_NODE \ + bash -c 'torchrun \ + --nproc-per-node=$SLURM_GPUS_ON_NODE \ + --nnodes=$SLURM_JOB_NUM_NODES \ + --rdzv-endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv-id $RDVZ_ID \ + --rdzv-backend c10d \ + simsiam/adil_main_simsiam.py \ + -a resnet50 \ + --fix-pred-lr \ + --distributed_mode \ + --batch-size=128 \ + --epochs=200 \ + --experiment="simsiam_icgan_seed43_bs128_rforig" \ + --resume_from_checkpoint="/projects/imagenet_synthetic/model_checkpoints/_original_simsiam/checkpoint_0099.pth.tar" \ + --seed=43 \ + --use_synthetic_data \ + --synthetic_data_dir="/projects/imagenet_synthetic/synthetic_icgan" \ + --synthetic_index_min=0 \ + --synthetic_index_max=4 \ + --generative_augmentation_prob=0.5' \ No newline at end of file diff --git a/train_simsiam_singlenode.slrm b/train_simsiam_singlenode.slrm new file mode 100644 index 0000000..b930947 --- /dev/null +++ b/train_simsiam_singlenode.slrm @@ -0,0 +1,40 @@ +#!/bin/bash + +#SBATCH --job-name="simsiam_single_train" +#SBATCH --partition=a40 +#SBATCH --qos=m +#SBATCH --nodes=1 +#SBATCH --gres=gpu:a40:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH --output=singlenode-%j.out +#SBATCH --error=singlenode-%j.err +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=12:00:00 + +# load virtual environment +source /ssd003/projects/aieng/envs/genssl2/bin/activate + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend +# export CUDA_LAUNCH_BLOCKING=1 + +export PYTHONPATH="." +nvidia-smi + +torchrun --nproc-per-node=4 --nnodes=1 simsiam/adil_main_simsiam.py \ + -a resnet50 \ + --fix-pred-lr \ + --distributed_mode \ + --batch-size=128 \ + --epochs=100 \ + --experiment="simsiam_stablediff_p0p5_seed43" \ + --resume_from_checkpoint="" \ + --seed=43 \ + --use_synthetic_data \ + --synthetic_data_dir="/projects/imagenet_synthetic/arashaf_stablediff_batched" \ + --synthetic_index_min=0 \ + --synthetic_index_max=1 \ + --generative_augmentation_prob=0.5 \ No newline at end of file