Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions create_imagenet_icgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import argparse
from torchvision import datasets, transforms
import time
import os
from icgan.config import get_config
from icgan.icgan_inference import ICGANInference
from pytorch_pretrained_biggan import convert_to_images
from icgan.data_utils import utils as data_utils
import torch
import numpy as np
from tqdm import tqdm

# Command-line interface for the synthetic-ImageNet ICGAN generation script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--outdir",
    type=str,
    nargs="?",
    help="dir to write results to",
    default="/projects/imagenet_synthetic/synthetic_icgan"
)
parser.add_argument(
    "--img_save_size",
    type=int,
    default=224,
    help="image saving size"
)
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="start index",
)
parser.add_argument(
    "--end",
    type=int,
    default=-1,
    help="end index (-1 means run to the end of the dataset)",
)
parser.add_argument(
    "--ith_sample",
    type=int,
    default=0,
    # Bug fix: help text was copy-pasted from --end ("end index").
    # This flag selects which sample to generate per image and is also
    # used as the RNG seed (see main(): config.seed = args.ith_sample).
    help="index of the sample generated per image (also used as seed)",
)

args = parser.parse_args()

def save(out, torch_format=True):
    """Convert one generated batch element to a PIL image.

    Args:
        out: a batch of image data (torch tensor when ``torch_format`` is
            True, otherwise an array-like accepted by ``convert_to_images``).
        torch_format: whether ``out`` is a torch tensor that must be moved
            to CPU and converted to numpy first.

    Returns:
        The first image of the batch as a PIL image.
    """
    if torch_format:
        with torch.no_grad():
            # Bug fix: .numpy() raises on tensors that require grad, and
            # no_grad() alone does not detach an already-created tensor.
            out = out.detach().cpu().numpy()
    img = convert_to_images(out)[0]
    return img

def save_images(path, images, out_dir):
    """Save generated images as JPEGs under out_dir/<class>/.

    Args:
        path: source ImageNet file path; its basename encodes the class
            name and image number (e.g. n01440764_10026.JPEG).
        images: iterable of generated images (array-like, HWC/ CHW as
            expected by ``save``).
        out_dir: root output directory.

    Bug fix: previously every image in ``images`` was written to the same
    file name (only ``args.ith_sample`` in the suffix), so each image
    overwrote the previous one; only the last survived.
    """
    # These values depend only on `path`; compute them once, not per image.
    out_folder = path.split("/")[-1].split(".")[0].split("_")[0]  # class name
    file_name = path.split("/")[-1].split(".")[0]  # class name + image number
    save_folder = os.path.join(out_dir, out_folder)  # one folder per class
    os.makedirs(save_folder, exist_ok=True)

    for j, img in enumerate(images):
        # Keep the historical name for the first sample (the common case,
        # config.num_samples == 1); disambiguate any additional samples so
        # they no longer overwrite each other.
        suffix = f"{args.ith_sample}" if j == 0 else f"{args.ith_sample}_{j}"
        save_file = os.path.join(save_folder, f"{file_name}_{suffix}.JPEG")
        pil_img = save(img[np.newaxis, ...], torch_format=False)
        pil_img.thumbnail((args.img_save_size, args.img_save_size))
        pil_img.save(save_file, format="JPEG")

def main():
    """Generate ICGAN samples for the slice [start, end) of ImageNet train.

    For each dataset index, runs ICGAN inference on the (transformed) image
    and writes the generated samples via save_images().
    """
    if args.outdir is not None:
        os.makedirs(args.outdir, exist_ok=True)

    # Initialize the ICGANInference class; the per-sample index doubles as
    # the RNG seed so different --ith_sample runs produce different samples.
    config = get_config()
    config.seed = args.ith_sample
    icgan_inference = ICGANInference(config)

    transform = transforms.Compose([
        data_utils.CenterCropLongEdge(),
        transforms.ToTensor(),
        transforms.Normalize(config.norm_mean, config.norm_std)])

    # NOTE(review): hard-coded cluster path — TODO make this configurable.
    imagenet_dataset = datasets.ImageNet(
        "/scratch/ssd004/datasets/imagenet256", split="train",
        transform=transform)

    n = len(imagenet_dataset)
    if args.end == -1:
        args.end = n
    # Explicit exceptions instead of asserts: asserts are stripped under -O.
    if args.start >= n:
        raise ValueError(f"--start must be < {n}, got {args.start}")
    if args.end > n:
        raise ValueError(f"--end must be <= {n}, got {args.end}")

    for i in tqdm(range(args.start, args.end)):
        batch = imagenet_dataset[i]
        images = batch[0]
        generated_images = icgan_inference.run_inference(
            input_image_tensor=images.unsqueeze(0))
        # Save under outdir/<class>/<class>_<imgnum>_<ith_sample>.JPEG
        path = imagenet_dataset.samples[i][0]
        save_images(path, generated_images, args.outdir)


if __name__ == "__main__":
    main()
11 changes: 7 additions & 4 deletions icgan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

def get_config():
config = ml_collections.ConfigDict()
config.num_samples = 16
config.seed = 123
config.num_samples = 1
config.truncation = 0.7
config.stochastic_truncation = False
config.noise_size = 128
config.batch_size = 4
config.experiment_name = "icgan_biggan_imagenet_res256"
config.feat_ext_path = "swav_pretrained.pth.tar"
config.batch_size = 1
config.experiment_name = (
"/ssd003/projects/aieng/genssl/icgan_biggan_imagenet_res256"
)
config.feat_ext_path = "/ssd003/projects/aieng/genssl/swav_pretrained.pth.tar"
config.size = 256
config.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
config.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
Expand Down
10 changes: 5 additions & 5 deletions icgan/icgan_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from scipy.stats import truncnorm
from torch import nn
import torchvision.transforms as transforms
import inference.utils as inference_utils
import data_utils.utils as data_utils
from icgan.inference import utils as inference_utils
from icgan.data_utils import utils as data_utils


class ICGANInference:
Expand All @@ -24,14 +24,14 @@ def replace_to_inplace_relu(self, model):
else:
self.replace_to_inplace_relu(child)

def load_icgan(self, root_="./"):
def load_icgan(self, root_=""):
root = os.path.join(root_, self.config.experiment_name)
config = torch.load("%s/%s.pth" % (root, "state_dict_best0"))["config"]
config["weights_root"] = root_
config["model_backbone"] = "biggan"
config["experiment_name"] = self.config.experiment_name
G, config = inference_utils.load_model_inference(config)
G.cuda()
G = G.cuda()
G.eval()
return G

Expand Down Expand Up @@ -62,7 +62,7 @@ def normality_loss(self, vec):
return mu2 + sigma2 - torch.log(sigma2) - 1

def load_generative_model(self):
model = self.load_icgan(root_="./")
model = self.load_icgan(root_="")
return model

def load_feature_extractor(self, feat_ext_path):
Expand Down
172 changes: 172 additions & 0 deletions img2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import argparse
import torch
from diffusers import StableUnCLIPImg2ImgPipeline, DDIMScheduler, DPMSolverSinglestepScheduler
# from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import time
import os
import json



class StableGenerator(object):
    """Thin wrapper around the Stable unCLIP img2img pipeline.

    Loads the fp16 pipeline, configures the scheduler (DDIM by default,
    single-step DPM when ``opt.dpm``), and exposes generate() to produce
    resized image variations.
    """

    def __init__(self, opt):
        self.opt = opt
        # Bug fix: the kwarg is ``variant`` — ``variation`` is not a
        # from_pretrained parameter, so the fp16 weight files were never
        # selected.
        self.model = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-unclip-small",
            torch_dtype=torch.float16,
            variant="fp16",
        )

        device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
        self.model.to(device)
        print(f"Using device: {device}")

        # rescale_betas_zero_snr: zero-terminal-SNR beta schedule rescaling.
        if opt.dpm:
            self.model.scheduler = DPMSolverSinglestepScheduler.from_config(
                self.model.scheduler.config, rescale_betas_zero_snr=True)
        else:
            self.model.scheduler = DDIMScheduler.from_config(
                self.model.scheduler.config, rescale_betas_zero_snr=True,
                timestep_spacing="trailing")

        print("Scheduler:", self.model.scheduler)

        # Output image size.
        self.height = self.opt.img_save_size
        self.width = self.opt.img_save_size

        # Hoisted out of generate(): the resize transform is loop-invariant
        # (also fixes the ``transfoem_2`` typo).
        self.resize = transforms.Resize(size=(self.height, self.width))

        # Number of diffusion inference steps.
        self.num_inference_steps = self.opt.steps

        # eta in [0, 1]; 0.0 corresponds to deterministic DDIM sampling.
        self.eta = self.opt.ddim_eta

    def generate(self, input_image, n_sample_per_image=10):
        """Return ``n_sample_per_image`` generated variations of
        ``input_image``, each resized to (height, width)."""
        synth_images = self.model(
            input_image,
            eta=self.eta,
            num_images_per_prompt=n_sample_per_image,
            num_inference_steps=self.num_inference_steps,
        ).images
        return [self.resize(img) for img in synth_images]



def main():
    """CLI entry point: generate Stable unCLIP img2img samples for a slice
    of ImageNet train and save them via save_images()."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--img_save_size",
        type=int,
        default=224,
        help="image saving size"
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=10,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--dpm",
        action='store_true',
        help="use DPM (2) sampler",
    )
    parser.add_argument(
        "--ddim",
        action='store_true',
        help="use ddim sampler",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        # Fixed unbalanced parenthesis in the help text.
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="device to use for inference",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="start index",
    )
    parser.add_argument(
        "--end",
        type=int,
        # Bug fix: --end had no default, so omitting it crashed later on
        # `opt.end <= n` (None is not comparable to int). -1 = full dataset.
        default=-1,
        help="end index (-1 means run to the end of the dataset)",
    )
    opt = parser.parse_args()

    if opt.outdir is not None:
        os.makedirs(opt.outdir, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize(size=(opt.img_save_size, opt.img_save_size)),
    ])

    # NOTE(review): hard-coded local dataset path — TODO make configurable.
    imagenet_dataset = datasets.ImageNet(
        "/local/ssd/m2kowal/imagenet", split="train", transform=transform)

    stable_generator = StableGenerator(opt)
    n = len(imagenet_dataset)
    if opt.end == -1:
        opt.end = n
    # Explicit exceptions instead of asserts: asserts are stripped under -O.
    if opt.start >= n:
        raise ValueError(f"--start must be < {n}, got {opt.start}")
    if opt.end > n:
        raise ValueError(f"--end must be <= {n}, got {opt.end}")

    for i in range(opt.start, opt.end):
        print(f"Batch {i}")
        batch = imagenet_dataset[i]
        images = batch[0]
        start = time.time()
        generated_images = stable_generator.generate(
            images, n_sample_per_image=opt.n_samples)
        end = time.time()
        print(f"Time taken to generate images: {end-start} seconds")

        # Save images under outdir/<class>/<class>_<imgnum>_<k>.png
        path = imagenet_dataset.samples[i][0]
        start = time.time()
        save_images(path, generated_images, opt.outdir)
        end = time.time()
        print(f"Time taken to save images: {end-start} seconds")

def save_images(path, images, out_dir):
    """Save each generated PIL image under out_dir/<class>/ as
    <class>_<imgnum>_<k>.png (k is 1-based).

    Args:
        path: source ImageNet file path; its basename encodes the class
            name and image number (e.g. n01440764_10026.JPEG).
        images: iterable of PIL images (anything with a .save(path) method).
        out_dir: root output directory.

    Bug fix: the original had a duplicated nested loop — an outer
    enumerate() re-ran the full inner save loop once per image, rewriting
    every file len(images) times.
    """
    # The destination folder and base name depend only on `path`.
    out_folder = path.split("/")[-1].split(".")[0].split("_")[0]  # class name
    file_name = path.split("/")[-1].split(".")[0]  # class name + image number
    save_folder = os.path.join(out_dir, out_folder)  # one folder per class
    os.makedirs(save_folder, exist_ok=True)

    for j, img in enumerate(images):
        save_file = os.path.join(save_folder, f"{file_name}_{j+1}.png")
        img.save(save_file)

# Script entry point: only run generation when executed directly, not on import.
if __name__ == "__main__":
    main()
13 changes: 7 additions & 6 deletions train_simclr.slrm
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
#!/bin/bash

#SBATCH --job-name=train_sunrgbd
#SBATCH --job-name=train_simclr
#SBATCH --partition=t4v2
#SBATCH --time=12:00:00
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=2G
#SBATCH --output=slurm-%N-%j.out
#SBATCH --output=slurm2-%N-%j.out

PY_ARGS=${@:1}

# load virtual environment
source /ssd003/projects/aieng/envs/genssl2/bin/activate

export NCCL_ASYNC_ERROR_HANDLING=1 # set to 1 for NCCL backend
export CUDA_LAUNCH_BLOCKING=1
# export CUDA_LAUNCH_BLOCKING=1

export MASTER_ADDR=$(hostname)
export MASTER_PORT=45679
export MASTER_PORT=45678

export PYTHONPATH="."
nvidia-smi
Expand All @@ -28,5 +28,6 @@ nvidia-smi
srun python run_simCLR.py \
--fp16-precision \
--distributed_mode \
--batch-size=4 \
--icgan_augmentation
--batch-size=16 \
--icgan_augmentation \
--subset_fraction 0.05